/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/group_assignment.h"

#include <cstdint>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/dtensor/cc/dstatus.h"

namespace tensorflow {
namespace dtensor {

GroupAssignment::ReplicaToDeviceMap
GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(int num_slices,
                                                               int slice_size) {
  absl::flat_hash_map<ReplicaId, DeviceId> map;
  for (int i = 0; i < num_slices; ++i) {
    for (int j = 0; j < slice_size; ++j) {
      map[ReplicaId{i * slice_size + j}] = DeviceId{i, j};
    }
  }
  return ReplicaToDeviceMap(std::move(map));
}
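// Example: DefaultReplicaToDeviceMap(/*num_slices=*/2, /*slice_size=*/4)
// assigns replica i * 4 + j to DeviceId{i, j}, so replica 5, for instance,
// maps to slice 1, core 1.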

GroupAssignment::ReplicaToDeviceMap::ReplicaToDeviceMap(
    absl::flat_hash_map<ReplicaId, DeviceId> map)
    : map_(std::move(map)) {
  std::set<int> slice_ids;
  for (const auto& entry : map_) {
    slice_ids.insert(entry.second.slice_id);
  }
  CHECK_GT(slice_ids.size(), 0);                // Crash OK
  CHECK_EQ(map_.size() % slice_ids.size(), 0);  // Crash OK
  num_slices_ = slice_ids.size();
}
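// Example: a map covering 8 replicas whose DeviceIds use slice_ids {0, 1}
// yields num_slices_ == 2; the CHECKs reject an empty map and any map whose
// replica count is not a multiple of the slice count.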

GroupAssignment::ReplicaGroups::ReplicaGroups(
    std::vector<std::vector<int>> replica_ids)
    : replica_ids_(std::move(replica_ids)) {
  int n = replica_ids_.size();
  CHECK_GT(n, 0);  // Crash OK
  int g = replica_ids_.front().size();
  CHECK_GT(g, 0);  // Crash OK
  std::set<int> seen_replica_ids;
  for (std::vector<int>& group : replica_ids_) {
    CHECK_EQ(group.size(), g);  // Crash OK
    for (int replica_id : group) {
      CHECK_GE(replica_id, 0);  // Crash OK
      bool inserted = seen_replica_ids.insert(replica_id).second;
      CHECK(inserted);  // Crash OK
    }
  }
}
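// Example: ReplicaGroups({{0, 1}, {2, 3}}) passes every CHECK, whereas a
// ragged input such as {{0, 1}, {2}} or a duplicate such as {{0, 1}, {1, 2}}
// would crash.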

mlir::DenseIntElementsAttr GroupAssignment::ReplicaGroups::ToMLIR(
    mlir::MLIRContext& context) const {
  auto shaped_type = mlir::RankedTensorType::get(
      {num_groups(), group_size()}, mlir::IntegerType::get(&context, 32));

  llvm::SmallVector<int32, 4> flat_replica_ids;
  flat_replica_ids.reserve(num_replica_ids());
  for (const std::vector<int>& group : replica_ids()) {
    flat_replica_ids.insert(flat_replica_ids.end(), group.begin(), group.end());
  }

  return mlir::DenseIntElementsAttr::get(shaped_type, flat_replica_ids);
}
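// Example: replica groups [[0, 1], [2, 3]] flatten row-major into
// [0, 1, 2, 3] and come back as dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
// in MLIR attribute syntax.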

std::string GroupAssignment::ReplicaGroups::ToString() const {
  return strings::StrCat(
      "[",
      str_util::Join(replica_ids(), ", ",
                     [](std::string* str, const std::vector<int>& group) {
                       strings::StrAppend(str, "[", str_util::Join(group, ", "),
                                          "]");
                     }),
      "]");
}
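// Example: replica groups [[0, 1], [2, 3]] render as "[[0, 1], [2, 3]]".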

StatusOr<GroupAssignment> GroupAssignment::FromMLIR(
    const mlir::DenseIntElementsAttr& group_assignment_attr,
    ReplicaToDeviceMap replica_to_device_map) {
  mlir::ShapedType shaped_type = group_assignment_attr.getType();
  if (!shaped_type.hasRank()) {
    return errors::InvalidArgument("group_assignment_attr must have a rank");
  }
  if (shaped_type.getRank() != 2) {
    return errors::InvalidArgument(
        "group_assignment_attr must have a rank of 2, got ",
        shaped_type.getRank());
  }
  llvm::ArrayRef<int64_t> shape = shaped_type.getShape();
  int num_groups = shape[0];
  if (num_groups <= 0) {
    return errors::InvalidArgument(
        "group_assignment_attr must have at least 1 group, got ", num_groups);
  }
  int group_size = shape[1];
  if (group_size <= 0) {
    return errors::InvalidArgument(
        "group_assignment_attr must have non-empty groups, got ", group_size,
        " replica IDs per group");
  }
  int num_replica_ids = num_groups * group_size;
  if (num_replica_ids != replica_to_device_map.num_cores()) {
    return errors::InvalidArgument("group_assignment_attr must have ",
                                   replica_to_device_map.num_cores(),
                                   " replica IDs, got ", num_replica_ids);
  }

  // Translate the flat group assignment to a 2D array.
  std::vector<std::vector<int>> replica_ids;
  replica_ids.resize(num_groups, std::vector<int>(group_size));
  std::set<int> seen_replica_ids;
  if (group_assignment_attr.getNumElements() != num_replica_ids) {
    return errors::InvalidArgument(
        "group_assignment_attr's element count must equal the number of "
        "replica IDs.");
  }
  for (const auto& it :
       llvm::enumerate(group_assignment_attr.getValues<llvm::APInt>())) {
    int index = it.index();
    int replica_id = it.value().getSExtValue();

    // If all replica IDs are within this range and distinct, they must be a
    // permutation of [0, ..., num_replica_ids).
    if (replica_id < 0 || replica_id >= num_replica_ids) {
      return errors::InvalidArgument("Out of range replica ID: ", replica_id);
    }
    if (!seen_replica_ids.insert(replica_id).second) {
      return errors::InvalidArgument(
          "All replica IDs in group_assignment must be distinct, saw ",
          replica_id, " more than once");
    }

    replica_ids[index / group_size][index % group_size] = replica_id;
  }

  GroupAssignment group_assignment(
      /*global=*/ReplicaGroups(std::move(replica_ids)),
      std::move(replica_to_device_map));
  TF_RETURN_IF_ERROR(group_assignment.GlobalToSlices());
  return group_assignment;
}
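// Illustrative usage sketch, where `attr` stands for a hypothetical 2x2 i32
// attribute built as in the ToMLIR example above:
//   StatusOr<GroupAssignment> ga = GroupAssignment::FromMLIR(
//       attr, GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(
//                 /*num_slices=*/2, /*slice_size=*/2));
// On success, the returned assignment already carries the per-host and
// per-slice views computed by GlobalToSlices.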

std::string GroupAssignment::ToString() const {
  return strings::StrCat(
      "GroupAssignment global: ", global_.ToString(), "; hosts: ",
      hosts_.empty()
          ? "<none>"
          : str_util::Join(hosts_, ", ",
                           [](std::string* str, const ReplicaGroups& groups) {
                             strings::StrAppend(str, groups.ToString());
                           }),
      "; slices: ",
      slices_.empty()
          ? "<none>"
          : str_util::Join(slices_, ", ",
                           [](std::string* str, const ReplicaGroups& groups) {
                             strings::StrAppend(str, groups.ToString());
                           }));
}
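// Example: before GlobalToSlices runs, hosts_ and slices_ are empty, so a
// one-group assignment prints as
// "GroupAssignment global: [[0, 1]]; hosts: <none>; slices: <none>".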

bool GroupAssignment::IsWithinSlices() const {
  // This function returns true iff no group in the global view gets split in
  // `GlobalToSlices`, i.e., the total group count remains the same.
  int total_num_groups = 0;
  for (int i = 0; i < num_slices(); i++) {
    total_num_groups += num_groups(i).ValueOrDie();
  }
  return total_num_groups == num_groups();
}
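// Example, assuming the default map with two slices of size two: a global
// group [[0, 1, 2, 3]] splits into one subgroup per slice (2 > 1), so this
// returns false; groups [[0, 1], [2, 3]] each stay on one slice, so it
// returns true.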

Status GroupAssignment::GlobalToSlices() {
  VLOG(2) << "Original group assignment: " << ToString();

  int num_slices = replica_to_device_map_.num_slices();
  if (num_slices == 0) {
    return errors::InvalidArgument("Unexpectedly empty replica_to_device_map.");
  }

  // For each replica group in global replica groups, divide its replicas based
  // on which slices they come from. Then, for each slice, collect subgroups
  // from every such division and form a new ReplicaGroups for that slice.
  std::vector<std::vector<std::vector<int>>> replica_groups_per_host;
  std::vector<std::vector<std::vector<int>>> replica_groups_per_slice;
  replica_groups_per_host.resize(num_slices, {});
  replica_groups_per_slice.resize(num_slices, {});

  for (const std::vector<int>& replica_group : replica_ids()) {
    std::vector<std::vector<int>> replica_group_divided_by_host;
    replica_group_divided_by_host.resize(num_slices, {});
    std::vector<std::vector<int>> replica_group_divided_by_slice;
    replica_group_divided_by_slice.resize(num_slices, {});

    for (int replica_id : replica_group) {
      // TODO(b/183426911): Use DeviceId::core_id in ReplicaGroup directly for
      // now. Integrate with device assignment with proper typing.
      DeviceId device_id = replica_to_device_map_.device_id(replica_id);
      replica_group_divided_by_host[device_id.slice_id].push_back(replica_id);
      replica_group_divided_by_slice[device_id.slice_id].push_back(
          device_id.core_id);
    }

    for (int i = 0; i < num_slices; ++i) {
      if (!replica_group_divided_by_host[i].empty()) {
        // Host meshes have the same global device and replica IDs as TPU
        // meshes. Let the first replica in every group do a host collective.
        replica_groups_per_host[i].push_back(
            std::vector<int>(1, replica_group_divided_by_host[i].front()));
      }
      if (!replica_group_divided_by_slice[i].empty()) {
        replica_groups_per_slice[i].push_back(
            std::move(replica_group_divided_by_slice[i]));
      }
    }
  }

  hosts_.reserve(num_slices);
  slices_.reserve(num_slices);
  for (int i = 0; i < num_slices; ++i) {
    hosts_.push_back(ReplicaGroups(std::move(replica_groups_per_host[i])));
    slices_.push_back(ReplicaGroups(std::move(replica_groups_per_slice[i])));
  }

  VLOG(2) << "Divided group assignment: " << ToString();
  return OkStatus();
}
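// Illustrative walkthrough, assuming the default map with num_slices = 2 and
// slice_size = 2: the global group [[0, 1, 2, 3]] divides into per-slice core
// IDs slices_ = {[[0, 1]], [[0, 1]]} and per-host groups
// hosts_ = {[[0]], [[2]]}, each host group keeping only the first replica
// that landed on its slice.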

}  // namespace dtensor
}  // namespace tensorflow