/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/group_assignment.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"

namespace tensorflow {
namespace dtensor {

namespace {

namespace ops_util = ::mlir::TF::collection_ops_util;

// Pad the merged tensor's element count to a multiple of kAllReducePadding
// (1024), so the delinearization-skipping optimization in XLA can be
// activated.
constexpr int32 kAllReducePadding = 1024;

// Returns true if `successor` depends on `predecessor`.
// TODO(jiawenhao): Repeatedly computing dependency sets for a large cluster can
// get expensive when the number of all-reduces is high. Consider building a
// cluster-scope op dependency graph ahead of time to amortize the cost.
bool DependsOn(mlir::Operation* successor, mlir::Operation* predecessor) {
  llvm::SmallVector<mlir::Operation*, 4> to_visit;
  llvm::SmallPtrSet<mlir::Operation*, 4> visited;
  to_visit.push_back(predecessor);
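  // Worklist traversal over the transitive users of `predecessor`: if
  // `successor` is reachable through use-def edges, it depends on
  // `predecessor`.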
  while (!to_visit.empty()) {
    mlir::Operation* producer = to_visit.pop_back_val();
    if (visited.contains(producer)) continue;
    visited.insert(producer);
    if (successor == producer) return true;
    for (mlir::Operation* user : producer->getUsers()) {
      if (visited.contains(user)) continue;
      to_visit.push_back(user);
    }
  }
  return false;
}

// Moves all usages of `a` (direct and transitive) to right after `b` in
// `cluster`, preserving the original order of moved ops.
// `a` and `b` must be in `cluster`. `a` must appear before `b` originally.
// `a` itself is not moved.
//
// For example, this program:
//
// tf_device.cluster() ({
//   %a = tf.A()
//   %1 = tf.C(%a)
//   %2 = tf.D(%a)
//   %3 = tf.E(%1, %2)
//   %b = tf.B()
//   %4 = tf.F(%3)
//   %5 = tf.G(%b)
//   tf_device.return()
// })
//
// will become this:
//
// tf_device.cluster() ({
//   %a = tf.A()
//   %b = tf.B()
//   %1 = tf.C(%a)
//   %2 = tf.D(%a)
//   %3 = tf.E(%1, %2)
//   %4 = tf.F(%3)
//   %5 = tf.G(%b)
//   tf_device.return()
// })
void MoveUsagesAfter(mlir::tf_device::ClusterOp cluster, mlir::Operation* a,
                     mlir::Operation* b) {
  llvm::SmallVector<mlir::Operation*, 4> to_visit;
  llvm::SmallPtrSet<mlir::Operation*, 4> visited;
  to_visit.push_back(a);
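  // Collect `a` and all of its direct and transitive users into `visited`.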
  while (!to_visit.empty()) {
    mlir::Operation* producer = to_visit.pop_back_val();
    if (visited.contains(producer)) continue;
    visited.insert(producer);
    for (mlir::Operation* user : producer->getUsers()) {
      if (visited.contains(user)) continue;
      to_visit.push_back(user);
    }
  }

  llvm::SmallVector<mlir::Operation*, 4> to_move;
  cluster.GetBody().walk([&](mlir::Operation* op) {
    if (op != a && visited.contains(op) && op->isBeforeInBlock(b)) {
      to_move.push_back(op);
    }
  });

  mlir::Operation* last = b;
  for (mlir::Operation* op : to_move) {
    if (mlir::dyn_cast<mlir::TF::YieldOp>(op)) {
      LOG(FATAL) << "Should never move YieldOp";  // Crash OK
    }
    op->moveAfter(last);
    last = op;
  }
}

// Merge all-reduces in the group into one all-reduce.
//
// Requirements:
// - The group should have at least two all-reduces.
// - They should be located next to each other in the parent block.
// - They should all have the same element type.
// - They should all have the same group assignment.
//
// The merged all-reduce operates on a 1D tensor whose length is the sum of
// the merged tensors' element counts, padded up to a multiple of 1024
// elements. (The padding is necessary for the XLA delinearization-skipping
// logic.) Each to-be-merged all-reduce flattens its input tensor and writes
// the resulting 1D tensor into the corresponding offset in the merged 1D
// tensor. After the merged all-reduce is done, the reverse happens: results
// are sliced out and reshaped to the original shape.
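//
// Merged 1D tensor layout (sketch):
//   [ group[0] flattened | group[1] flattened | ... | zero padding ]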
mlir::LogicalResult MergeAllReduceGroup(
    std::vector<mlir::TF::DTensorAllReduceOp>& all_reduce_group) {
  // Create the initial all-zero merged tensor.
  // The merged tensor's size is the sum of all individual all-reduces' sizes.
  int num_all_reduces = all_reduce_group.size();
  DCHECK(num_all_reduces > 1)
      << "All reduce group size expected to be greater than 1.";
  int total_num_elements = 0;
  std::vector<llvm::ArrayRef<int64_t>> all_reduce_shapes;
  all_reduce_shapes.reserve(num_all_reduces);
  for (mlir::TF::DTensorAllReduceOp& all_reduce : all_reduce_group) {
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }
    int num_elements = all_reduce_ranked_type.getNumElements();
    total_num_elements += num_elements;
    all_reduce_shapes.push_back(all_reduce_ranked_type.getShape());
  }

  // Pad the merged tensor's element count to a multiple of kAllReducePadding
  // (1024), so the delinearization-skipping optimization in XLA can be
  // activated.
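  // For example, 16 elements are padded up to 1024, and 1500 up to 2048.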
  if (total_num_elements % kAllReducePadding != 0) {
    total_num_elements =
        total_num_elements / kAllReducePadding * kAllReducePadding +
        kAllReducePadding;
  }

  // Fill the merged tensor with 0 initially.
  mlir::OpBuilder builder(all_reduce_group[0]);
  mlir::Location loc = all_reduce_group[0].getLoc();
  mlir::Type elem_type = all_reduce_group[0].getType().getElementType();
  auto zero_scalar = ops_util::CreateScalarConst(0, builder, loc);
  auto zero_scalar_elem_type = builder.create<mlir::TF::CastOp>(
      loc, mlir::RankedTensorType::get({}, elem_type), zero_scalar);
  auto merged = builder.create<mlir::TF::FillOp>(
      loc, ops_util::GetR1Const({total_num_elements}, builder, loc),
      zero_scalar_elem_type);

  // Store every all-reduce's input at an offset location in the merged tensor,
  // as a 1D tensor.
  int offset_num_elements = 0;
  std::vector<mlir::Type> flattened_types;
  flattened_types.reserve(num_all_reduces);
  mlir::TF::XlaDynamicUpdateSliceOp updated;
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::Location loc = all_reduce.getLoc();
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }

    int num_elements = all_reduce_ranked_type.getNumElements();
    auto flattened = builder.create<mlir::TF::ReshapeOp>(
        loc, all_reduce.input(),
        ops_util::GetR1Const({num_elements}, builder, loc));
    flattened_types.push_back(flattened.getType());
    auto indices = ops_util::GetR1Const({offset_num_elements}, builder, loc);
    updated = builder.create<mlir::TF::XlaDynamicUpdateSliceOp>(
        loc, merged.getType(),
        /*input=*/i == 0 ? merged.getResult() : updated.getResult(),
        /*update=*/flattened, indices);
    offset_num_elements += num_elements;
  }
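  // `updated` now holds every input, flattened and packed at its offset in
  // the zero-padded 1D buffer.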

  // All-reduce the updated merged tensor.
  auto merged_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      all_reduce_group[0].getLoc(), updated.getType(), updated,
      all_reduce_group[0].group_assignment(), all_reduce_group[0].reduce_op(),
      all_reduce_group[0].device_type());
  SetSingleLayoutOnOp(
      merged_all_reduce,
      ExtractSingleLayoutFromOp(all_reduce_group[0]).ValueOrDie().value());

  // Slice out the original all-reduces, and reshape back to the original
  // shape.
  offset_num_elements = 0;
  std::vector<mlir::TF::ReshapeOp> replacements;
  replacements.reserve(num_all_reduces);
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::Location loc = all_reduce.getLoc();
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }
    int num_elements = all_reduce_ranked_type.getNumElements();
    auto slice = builder.create<mlir::TF::SliceOp>(
        loc, flattened_types[i], /*input=*/merged_all_reduce,
        /*begin=*/ops_util::GetR1Const({offset_num_elements}, builder, loc),
        /*size=*/ops_util::GetR1Const({num_elements}, builder, loc));
    auto replacement = builder.create<mlir::TF::ReshapeOp>(
        loc, slice.getResult(),
        ops_util::GetR1Const(all_reduce_shapes[i], builder, loc));
    replacements.push_back(replacement);
    offset_num_elements += num_elements;
  }

  // Replace usages and clean up.
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::TF::ReshapeOp& replacement = replacements[i];
    all_reduce.replaceAllUsesWith(replacement.getResult());
    all_reduce.erase();
  }
  return mlir::success();
}

// Dump the dependencies between AllReduce ops as a DOT graph.
std::string DrawAllReduceDependencies(
    std::vector<mlir::TF::DTensorAllReduceOp> all_reduces) {
  std::vector<std::vector<int>> dependents(all_reduces.size(),
                                           std::vector<int>());
  for (int j = 0; j < all_reduces.size(); ++j) {
    mlir::TF::DTensorAllReduceOp later = all_reduces[j];
    for (int i = 0; i < j; ++i) {
      mlir::TF::DTensorAllReduceOp earlier = all_reduces[i];
      DCHECK(!DependsOn(earlier, later));
      if (earlier->getBlock() != later->getBlock() ||
          DependsOn(later, earlier)) {
        dependents[i].push_back(j);
      }
    }
  }
  std::string output = "digraph all_reduces {\n";
  for (int i = 0; i < dependents.size(); i++) {
    strings::StrAppend(&output, i);
    strings::StrAppend(&output, "\n");
  }
  for (int i = 0; i < dependents.size(); i++) {
    for (int j : dependents[i]) {
      strings::StrAppend(&output, i, " -> ", j, "\n");
    }
  }
  output += "}";
  return output;
}

// Combine cross-slice DTensorAllReduce ops of the same element type and group
// assignment into as few groups as possible. Only independent ops can be
// combined together.
//
// For example, this program:
//
// clang-format off
// NOLINTBEGIN(whitespace/line_length)
// %0 = "tf_device.cluster"() ({
//   %1 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %3 = "tf.DTensorAllReduce"(%1, %2) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32>
//   %4 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %5 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %6 = "tf.DTensorAllReduce"(%4, %5) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32>
//   %7 = "tf.Add"(%3, %6) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
//   "tf_device.return"(%7) : (tensor<4x4xf32>) -> ()
// }) : () -> tensor<4x4xf32>
// NOLINTEND
// clang-format on
//
// will become this:
//
// clang-format off
// NOLINTBEGIN(whitespace/line_length)
// %0 = "tf_device.cluster"() ( {
//   %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %cst_0 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %cst_1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %cst_2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %cst_3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
//   %1 = "tf.Cast"(%cst_3) {Truncate = false} : (tensor<i32>) -> tensor<f32>
//   %cst_4 = "tf.Const"() {value = dense<1024> : tensor<1xi32>} : () -> tensor<1xi32>
//   %2 = "tf.Fill"(%cst_4, %1) : (tensor<1xi32>, tensor<f32>) -> tensor<1024xf32>
//   %cst_5 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %3 = "tf.Reshape"(%cst, %cst_5) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
//   %4 = "tf.XlaDynamicUpdateSlice"(%2, %3, %cst_6) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32>
//   %cst_7 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %5 = "tf.Reshape"(%cst_1, %cst_7) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_8 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %6 = "tf.XlaDynamicUpdateSlice"(%4, %5, %cst_8) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32>
//   %7 = "tf.DTensorAllReduce"(%6, %cst_0) {reduce_op = "Add"} : (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32>
//   %cst_9 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
//   %cst_10 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %8 = "tf.Slice"(%7, %cst_9, %cst_10) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_11 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32>
//   %9 = "tf.Reshape"(%8, %cst_11) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32>
//   %cst_12 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %cst_13 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %10 = "tf.Slice"(%7, %cst_12, %cst_13) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_14 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32>
//   %11 = "tf.Reshape"(%10, %cst_14) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32>
//   %12 = "tf.Add"(%9, %11) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
//   tf_device.return %12 : tensor<4x4xf32>
// }) : () -> tensor<4x4xf32>
// NOLINTEND
// clang-format on
mlir::LogicalResult CombineAllReduceOpsOfSameTypeAndGroupAssignment(
    mlir::tf_device::ClusterOp cluster,
    const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) {
  // Drop within-slice all-reduces.
  std::vector<mlir::TF::DTensorAllReduceOp> cross_slice_all_reduces;
  for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) {
    mlir::DenseIntElementsAttr group_assignment_attr;
    if (!matchPattern(all_reduce.group_assignment(),
                      m_Constant(&group_assignment_attr))) {
      return all_reduce.emitOpError("group_assignment should be a constant");
    }
    // LINT.IfChange
    int num_slices = NumClients();
    int slice_size = kTpuDonutSize;
    if (group_assignment_attr.getNumElements() < kTpuDonutSize) {
      DCHECK_EQ(num_slices, 1) << "Num slices expected to be equal to 1.";
      slice_size = group_assignment_attr.getNumElements();
    }
    StatusOr<GroupAssignment> group_assignment = GroupAssignment::FromMLIR(
        group_assignment_attr,
        GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(
            num_slices, slice_size));
    // LINT.ThenChange(//tensorflow/dtensor/mlir/utils/collective_lowering.cc)
    if (!group_assignment.ok()) {
      return all_reduce.emitOpError(
          llvm::formatv("Failed to create a GroupAssignment due to {0}",
                        group_assignment.status().error_message()));
    }
    // Unit tests have only one slice. Always combine all all-reduces in them.
    if (group_assignment->num_slices() == 1 ||
        !group_assignment->IsWithinSlices()) {
      cross_slice_all_reduces.push_back(all_reduce);
    }
  }

  // A single op has nothing to combine with.
  int num_all_reduces = cross_slice_all_reduces.size();
  if (num_all_reduces <= 1) return mlir::success();

  // Export the all-reduces as a DOT graph.
  VLOG(4) << "Visualizing AllReduce dependencies:\n"
          << DrawAllReduceDependencies(cross_slice_all_reduces);

  // Build a reverse adjacency matrix from dependents to requirements.
  std::vector<std::vector<int>> requirements(num_all_reduces,
                                             std::vector<int>());
  for (int i = 0; i < num_all_reduces - 1; ++i) {
    mlir::TF::DTensorAllReduceOp requirement = cross_slice_all_reduces[i];
    for (int j = i + 1; j < num_all_reduces; ++j) {
      mlir::TF::DTensorAllReduceOp dependent = cross_slice_all_reduces[j];
      DCHECK(
          !DependsOn(requirement, dependent));  // guaranteed by program order
      // In this example, all three DTensorAllReduce ops are independent of
      // each other according to the MLIR value use-def chains considered by
      // DependsOn. However, moving all three to after the WhileRegion and
      // combining them would break the program.
      //
      //   %3 = tf.DTensorAllReduce(%1, %2)
      //   %4 = tf.WhileRegion(%1) ({
      //   ^bb0(%arg):
      //     %5 = tf.ToBool(%arg)
      //     tf.Yield(%5)
      //   }, {
      //     %6 = tf.DTensorAllReduce(%1, %2)
      //     tf.Yield(%6)
      //   })
      //   %7 = tf.DTensorAllReduce(%1, %2)
      //
      // Therefore, in addition to DependsOn, we also check if two
      // DTensorAllReduceOps belong to different blocks. If they do, since they
      // exist in the same ClusterOp, one or both of them must be inside a
      // control flow region block. We treat them as if there were a dependency
      // between them.
      //
      // In the example above, the second DTensorAllReduceOp would "depend on"
      // the first one, and the third on the second. This effectively prevents
      // any two of these DTensorAllReduceOps from merging together.
      if (requirement->getBlock() != dependent->getBlock() ||
          DependsOn(dependent, requirement)) {
        requirements[j].push_back(i);
      }
    }
  }

  // Traverse the adjacency matrix layer by layer to find combination groups.
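  // Each iteration below peels off the ops whose requirements have all been
  // fulfilled (one layer of a topological sort); ops within the same layer
  // are mutually independent and form one combination group.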
  std::vector<std::vector<mlir::TF::DTensorAllReduceOp>> all_reduce_groups;
  std::set<int> fulfilled;
  while (fulfilled.size() < cross_slice_all_reduces.size()) {
    std::vector<int> fulfilled_this_layer;
    for (int j = 0; j < requirements.size(); ++j) {
      if (fulfilled.count(j) > 0) continue;
      bool requirements_met = true;
      for (int i : requirements[j]) {
        if (fulfilled.count(i) == 0) {
          requirements_met = false;
          break;
        }
      }
      if (requirements_met) {
        fulfilled_this_layer.push_back(j);
      }
    }
    VLOG(4) << "Fulfilled: " << str_util::Join(fulfilled_this_layer, ", ");
    all_reduce_groups.push_back({});
    for (int i : fulfilled_this_layer) {
      fulfilled.insert(i);
      all_reduce_groups.back().push_back(cross_slice_all_reduces[i]);
    }
  }
  VLOG(2) << num_all_reduces << " all-reduce ops in "
          << all_reduce_groups.size() << " groups";

  // Move all-reduces in the same group together and combine them.
  for (auto& all_reduce_group : all_reduce_groups) {
    int num_all_reduces = all_reduce_group.size();
    if (num_all_reduces <= 1) continue;
    mlir::TF::DTensorAllReduceOp final_all_reduce =
        all_reduce_group[num_all_reduces - 1];
    for (int i = num_all_reduces - 2; i >= 0; --i) {
      mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i];
      MoveUsagesAfter(cluster, all_reduce, final_all_reduce);
    }
    for (int i = 0; i < num_all_reduces - 1; ++i) {
      mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i];
      all_reduce->moveBefore(final_all_reduce);
    }
    auto merge_result = MergeAllReduceGroup(all_reduce_group);
    if (merge_result.failed()) return merge_result;
  }

  return mlir::success();
}

// Returns true if the two group assignments are the same value, or if both
// are constant and elementwise equal.
bool same_group_assignments(mlir::Value group_assignment_a,
                            mlir::Value group_assignment_b) {
  if (group_assignment_a == group_assignment_b) {
    return true;
  }
  mlir::DenseIntElementsAttr attr_a;
  if (!matchPattern(group_assignment_a, m_Constant(&attr_a))) {
    return false;
  }
  mlir::DenseIntElementsAttr attr_b;
  if (!matchPattern(group_assignment_b, m_Constant(&attr_b))) {
    return false;
  }
  if (attr_a.getType().getShape() != attr_b.getType().getShape()) {
    return false;
  }
  return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end());
}

// Combines DTensorAllReduce ops of the same element type into as few groups as
// possible. Only ops with the same group assignment can be combined together.
mlir::LogicalResult CombineAllReduceOpsOfSameType(
    mlir::tf_device::ClusterOp cluster,
    const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) {
  // Maintain a list of seen group assignments, sorted by first appearance.
  // Also find and store all-reduces by group assignment. Use the first
  // mlir::Value that contains a certain group assignment to represent all the
  // same group assignments.
  std::vector<mlir::Value> group_assignments;
  llvm::DenseMap<mlir::Value, std::vector<mlir::TF::DTensorAllReduceOp>>
      all_reduces_by_group_assignment;
  for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) {
    mlir::Value group_assignment = all_reduce.group_assignment();
    bool seen = false;
    for (mlir::Value seen_group_assignment : group_assignments) {
      if (same_group_assignments(group_assignment, seen_group_assignment)) {
        group_assignment = seen_group_assignment;
        seen = true;
        break;
      }
    }
    if (!seen) group_assignments.push_back(group_assignment);
    all_reduces_by_group_assignment[group_assignment].push_back(all_reduce);
  }

  // Combine all-reduces of the same group assignment in first-appearance
  // order.
  for (mlir::Value group_assignment : group_assignments) {
    mlir::LogicalResult result =
        CombineAllReduceOpsOfSameTypeAndGroupAssignment(
            cluster, all_reduces_by_group_assignment[group_assignment]);
    if (mlir::failed(result)) return result;
  }

  return mlir::success();
}

struct DTensorAllReduceCombineOptimization
    : public DTensorAllReduceCombineOptimizationBase<
          DTensorAllReduceCombineOptimization> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();
    function.walk([&](mlir::tf_device::ClusterOp cluster) {
      // Maintain a list of seen element types, sorted by first appearance.
      // Also find and store all-reduces by element type.
      std::vector<mlir::Type> elem_types;
      llvm::DenseMap<mlir::Type, std::vector<mlir::TF::DTensorAllReduceOp>>
          all_reduces_by_elem_type;
      cluster.GetBody().walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
        mlir::Type elem_type = all_reduce.getType().getElementType();
        if (std::find(elem_types.begin(), elem_types.end(), elem_type) ==
            elem_types.end()) {
          elem_types.push_back(elem_type);
        }
        all_reduces_by_elem_type[elem_type].push_back(all_reduce);
      });

      // Combine all-reduces of the same element type in first-appearance
      // order.
      for (mlir::Type elem_type : elem_types) {
        if (mlir::failed(CombineAllReduceOpsOfSameType(
                cluster, all_reduces_by_elem_type[elem_type]))) {
          return signalPassFailure();
        }
      }
    });
  }
};

}  // namespace

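// Usage sketch (illustrative; assumes an mlir::PassManager `pm` rooted at the
// enclosing module):
//   pm.addNestedPass<mlir::func::FuncOp>(
//       CreateDTensorAllReduceCombineOptimization());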
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceCombineOptimization() {
  return std::make_unique<DTensorAllReduceCombineOptimization>();
}

}  // namespace dtensor
}  // namespace tensorflow