/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/group_assignment.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"

namespace tensorflow {
namespace dtensor {

namespace {

namespace ops_util = ::mlir::TF::collection_ops_util;

// Pad the merged tensor size to a multiple of kAllReducePadding (1024)
// elements, so that the delinearization-skipping optimization in XLA can
// kick in.
constexpr int32 kAllReducePadding = 1024;

// Returns true if `successor` depends on `predecessor`.
// TODO(jiawenhao): Repeatedly computing dependency sets for a large cluster can
// get expensive when the number of all-reduces is high. Consider building a
// cluster-scope op dependency graph ahead of time to amortize the cost.
bool DependsOn(mlir::Operation* successor, mlir::Operation* predecessor) {
  llvm::SmallVector<mlir::Operation*, 4> to_visit;
  llvm::SmallPtrSet<mlir::Operation*, 4> visited;
  to_visit.push_back(predecessor);
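  // Worklist traversal over the transitive users of `predecessor`: if the
  // use-def chains starting at `predecessor` ever reach `successor`, then
  // `successor` depends on `predecessor`.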
  while (!to_visit.empty()) {
    mlir::Operation* producer = to_visit.pop_back_val();
    if (visited.contains(producer)) continue;
    visited.insert(producer);
    if (successor == producer) return true;
    for (mlir::Operation* user : producer->getUsers()) {
      if (visited.contains(user)) continue;
      to_visit.push_back(user);
    }
  }
  return false;
}

// Moves all usages of `a` (direct and transitive) to right after `b` in
// `cluster`, preserving the original order of moved ops.
// `a` and `b` must be in `cluster`. `a` must appear before `b` originally.
// `a` itself is not moved.
//
// For example, this program:
//
// tf_device.cluster() ({
//   %a = tf.A()
//   %1 = tf.C(%a)
//   %2 = tf.D(%a)
//   %3 = tf.E(%1, %2)
//   %b = tf.B()
//   %4 = tf.F(%3)
//   %5 = tf.G(%b)
//   tf_device.return()
// })
//
// will become this:
//
// tf_device.cluster() ({
//   %a = tf.A()
//   %b = tf.B()
//   %1 = tf.C(%a)
//   %2 = tf.D(%a)
//   %3 = tf.E(%1, %2)
//   %4 = tf.F(%3)
//   %5 = tf.G(%b)
//   tf_device.return()
// })
void MoveUsagesAfter(mlir::tf_device::ClusterOp cluster, mlir::Operation* a,
                     mlir::Operation* b) {
  llvm::SmallVector<mlir::Operation*, 4> to_visit;
  llvm::SmallPtrSet<mlir::Operation*, 4> visited;
  to_visit.push_back(a);
  while (!to_visit.empty()) {
    mlir::Operation* producer = to_visit.pop_back_val();
    if (visited.contains(producer)) continue;
    visited.insert(producer);
    for (mlir::Operation* user : producer->getUsers()) {
      if (visited.contains(user)) continue;
      to_visit.push_back(user);
    }
  }

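  // Collect the ops to move. `GetBody().walk` visits ops in program order, so
  // `to_move` preserves the original relative order of the moved ops.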
  llvm::SmallVector<mlir::Operation*, 4> to_move;
  cluster.GetBody().walk([&](mlir::Operation* op) {
    if (op != a && visited.contains(op) && op->isBeforeInBlock(b)) {
      to_move.push_back(op);
    }
  });

  mlir::Operation* last = b;
  for (mlir::Operation* op : to_move) {
    if (mlir::dyn_cast<mlir::TF::YieldOp>(op)) {
      LOG(FATAL) << "Should never move YieldOp";  // Crash OK
    }
    op->moveAfter(last);
    last = op;
  }
}

// Merge all-reduces in the group into one all-reduce.
//
// Requirements:
//   - The group should have at least two all-reduces.
//   - They should be located next to each other in the parent block.
//   - They should all have the same element type.
//   - They should all have the same group assignment.
//
// The merged all-reduce operates on a 1D tensor whose size is the sum of all
// merged all-reduce tensor sizes, padded to a multiple of kAllReducePadding
// elements. (The padding is necessary for the XLA delinearization-skipping
// logic.) Each to-be-merged all-reduce flattens its input tensor and writes
// the resulting 1D tensor into the corresponding offset in the merged 1D
// tensor. After the merged all-reduce is done, the reverse happens: results
// are sliced out and reshaped to the original shape.
mlir::LogicalResult MergeAllReduceGroup(
    std::vector<mlir::TF::DTensorAllReduceOp>& all_reduce_group) {
  // Create the initial all-zero merged tensor.
  // The merged tensor's size is the sum of all individual all-reduces' sizes.
  int num_all_reduces = all_reduce_group.size();
  DCHECK(num_all_reduces > 1)
      << "All reduce group size expected to be greater than 1.";
  int total_num_elements = 0;
  std::vector<llvm::ArrayRef<int64_t>> all_reduce_shapes;
  all_reduce_shapes.reserve(num_all_reduces);
  for (mlir::TF::DTensorAllReduceOp& all_reduce : all_reduce_group) {
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }
    int num_elements = all_reduce_ranked_type.getNumElements();
    total_num_elements += num_elements;
    all_reduce_shapes.push_back(all_reduce_ranked_type.getShape());
  }

  // Pad the merged tensor size to a multiple of kAllReducePadding elements,
  // so that the delinearization-skipping optimization in XLA can kick in.
  if (total_num_elements % kAllReducePadding != 0) {
    total_num_elements =
        total_num_elements / kAllReducePadding * kAllReducePadding +
        kAllReducePadding;
  }
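  // For example, 16 elements pad up to 1024, and 1030 elements pad up to 2048.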

  // Fill the merged tensor with 0 initially.
  mlir::OpBuilder builder(all_reduce_group[0]);
  mlir::Location loc = all_reduce_group[0].getLoc();
  mlir::Type elem_type = all_reduce_group[0].getType().getElementType();
  auto zero_scalar = ops_util::CreateScalarConst(0, builder, loc);
  auto zero_scalar_elem_type = builder.create<mlir::TF::CastOp>(
      loc, mlir::RankedTensorType::get({}, elem_type), zero_scalar);
  auto merged = builder.create<mlir::TF::FillOp>(
      loc, ops_util::GetR1Const({total_num_elements}, builder, loc),
      zero_scalar_elem_type);

  // Store every all-reduce's input at an offset location in the merged tensor,
  // as a 1D tensor.
  int offset_num_elements = 0;
  std::vector<mlir::Type> flattened_types;
  flattened_types.reserve(num_all_reduces);
  mlir::TF::XlaDynamicUpdateSliceOp updated;
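  // Each iteration below writes one flattened input into the merged tensor at
  // its running offset, threading the previous XlaDynamicUpdateSlice result in
  // as the input of the next update.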
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::Location loc = all_reduce.getLoc();
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }

    int num_elements = all_reduce_ranked_type.getNumElements();
    auto flattened = builder.create<mlir::TF::ReshapeOp>(
        loc, all_reduce.input(),
        ops_util::GetR1Const({num_elements}, builder, loc));
    flattened_types.push_back(flattened.getType());
    auto indices = ops_util::GetR1Const({offset_num_elements}, builder, loc);
    updated = builder.create<mlir::TF::XlaDynamicUpdateSliceOp>(
        loc, merged.getType(),
        /*input=*/i == 0 ? merged.getResult() : updated.getResult(),
        /*update=*/flattened, indices);
    offset_num_elements += num_elements;
  }

  // All-reduce the updated merged tensor.
  auto merged_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      all_reduce_group[0].getLoc(), updated.getType(), updated,
      all_reduce_group[0].group_assignment(), all_reduce_group[0].reduce_op(),
      all_reduce_group[0].device_type());
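  // The merged all-reduce reuses the group assignment, reduce op, device type,
  // and layout of the first all-reduce in the group.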
  SetSingleLayoutOnOp(
      merged_all_reduce,
      ExtractSingleLayoutFromOp(all_reduce_group[0]).ValueOrDie().value());

  // Slice out the original all-reduces, and reshape back to the original shape.
  offset_num_elements = 0;
  std::vector<mlir::TF::ReshapeOp> replacements;
  replacements.reserve(num_all_reduces);
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::Location loc = all_reduce.getLoc();
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }
    int num_elements = all_reduce_ranked_type.getNumElements();
    auto slice = builder.create<mlir::TF::SliceOp>(
        loc, flattened_types[i], /*input=*/merged_all_reduce,
        /*begin=*/ops_util::GetR1Const({offset_num_elements}, builder, loc),
        /*size=*/ops_util::GetR1Const({num_elements}, builder, loc));
    auto replacement = builder.create<mlir::TF::ReshapeOp>(
        loc, slice.getResult(),
        ops_util::GetR1Const(all_reduce_shapes[i], builder, loc));
    replacements.push_back(replacement);
    offset_num_elements += num_elements;
  }

  // Replace usages and clean up.
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::TF::ReshapeOp& replacement = replacements[i];
    all_reduce.replaceAllUsesWith(replacement.getResult());
    all_reduce.erase();
  }
  return mlir::success();
}

// Dump the dependencies between AllReduce ops as a DOT graph.
std::string DrawAllReduceDependencies(
    std::vector<mlir::TF::DTensorAllReduceOp> all_reduces) {
  std::vector<std::vector<int>> dependents(all_reduces.size(),
                                           std::vector<int>());
  for (int j = 0; j < all_reduces.size(); ++j) {
    mlir::TF::DTensorAllReduceOp later = all_reduces[j];
    for (int i = 0; i < j; ++i) {
      mlir::TF::DTensorAllReduceOp earlier = all_reduces[i];
      DCHECK(!DependsOn(earlier, later));
      if (earlier->getBlock() != later->getBlock() ||
          DependsOn(later, earlier)) {
        dependents[i].push_back(j);
      }
    }
  }
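  // Emit one DOT node per all-reduce, named by its index in program order,
  // followed by one edge per dependency: "0 -> 2" means all-reduce 2 depends
  // on (or is serialized after) all-reduce 0.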
  std::string output = "digraph all_reduces {\n";
  for (int i = 0; i < dependents.size(); i++) {
    strings::StrAppend(&output, i);
    strings::StrAppend(&output, "\n");
  }
  for (int i = 0; i < dependents.size(); i++) {
    for (int j : dependents[i]) {
      strings::StrAppend(&output, i, " -> ", j, "\n");
    }
  }
  output += "}";
  return output;
}

// Combine cross-slice DTensorAllReduce ops of the same element type and group
// assignment into as few groups as possible. Only independent ops can be
// combined together.
//
// For example, this program:
//
// clang-format off
// NOLINTBEGIN(whitespace/line_length)
// %0 = "tf_device.cluster"() ({
//   %1 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %3 = "tf.DTensorAllReduce"(%1, %2) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32>
//   %4 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %5 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %6 = "tf.DTensorAllReduce"(%4, %5) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32>
//   %7 = "tf.Add"(%3, %6) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
//   "tf_device.return"(%7) : (tensor<4x4xf32>) -> ()
// }) : () -> tensor<4x4xf32>
// NOLINTEND
// clang-format on
//
// will become this:
//
// clang-format off
// NOLINTBEGIN(whitespace/line_length)
// %0 = "tf_device.cluster"() ( {
//   %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %cst_0 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %cst_1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %cst_2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %cst_3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
//   %1 = "tf.Cast"(%cst_3) {Truncate = false} : (tensor<i32>) -> tensor<f32>
//   %cst_4 = "tf.Const"() {value = dense<1024> : tensor<1xi32>} : () -> tensor<1xi32>
//   %2 = "tf.Fill"(%cst_4, %1) : (tensor<1xi32>, tensor<f32>) -> tensor<1024xf32>
//   %cst_5 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %3 = "tf.Reshape"(%cst, %cst_5) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
//   %4 = "tf.XlaDynamicUpdateSlice"(%2, %3, %cst_6) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32>
//   %cst_7 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %5 = "tf.Reshape"(%cst_1, %cst_7) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_8 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %6 = "tf.XlaDynamicUpdateSlice"(%4, %5, %cst_8) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32>
//   %7 = "tf.DTensorAllReduce"(%6, %cst_0) {reduce_op = "Add"} : (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32>
//   %cst_9 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
//   %cst_10 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %8 = "tf.Slice"(%7, %cst_9, %cst_10) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_11 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32>
//   %9 = "tf.Reshape"(%8, %cst_11) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32>
//   %cst_12 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %cst_13 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %10 = "tf.Slice"(%7, %cst_12, %cst_13) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_14 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32>
//   %11 = "tf.Reshape"(%10, %cst_14) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32>
//   %12 = "tf.Add"(%9, %11) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
//   tf_device.return %12 : tensor<4x4xf32>
// }) : () -> tensor<4x4xf32>
// NOLINTEND
// clang-format on
mlir::LogicalResult CombineAllReduceOpsOfSameTypeAndGroupAssignment(
    mlir::tf_device::ClusterOp cluster,
    const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) {
  // Drop within-slice all-reduces.
  std::vector<mlir::TF::DTensorAllReduceOp> cross_slice_all_reduces;
  for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) {
    mlir::DenseIntElementsAttr group_assignment_attr;
    if (!matchPattern(all_reduce.group_assignment(),
                      m_Constant(&group_assignment_attr))) {
      return all_reduce.emitOpError("group_assignment should be a constant");
    }
    // LINT.IfChange
    int num_slices = NumClients();
    int slice_size = kTpuDonutSize;
    if (group_assignment_attr.getNumElements() < kTpuDonutSize) {
      DCHECK_EQ(num_slices, 1) << "Num slices expected to be equal to 1.";
      slice_size = group_assignment_attr.getNumElements();
    }
    StatusOr<GroupAssignment> group_assignment = GroupAssignment::FromMLIR(
        group_assignment_attr,
        GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(
            num_slices, slice_size));
    // LINT.ThenChange(//tensorflow/dtensor/mlir/utils/collective_lowering.cc)
    if (!group_assignment.ok()) {
      return all_reduce.emitOpError(
          llvm::formatv("Failed to create a GroupAssignment due to {0}",
                        group_assignment.status().error_message()));
    }
    // Unit tests have only one slice. Always combine all all-reduces in them.
    if (group_assignment->num_slices() == 1 ||
        !group_assignment->IsWithinSlices()) {
      cross_slice_all_reduces.push_back(all_reduce);
    }
  }

  // A single op has nothing to combine with.
  int num_all_reduces = cross_slice_all_reduces.size();
  if (num_all_reduces <= 1) return mlir::success();

  // Export the all-reduces as a DOT graph.
  VLOG(4) << "Visualizing AllReduce dependencies:\n"
          << DrawAllReduceDependencies(cross_slice_all_reduces);

  // Build a reverse adjacency matrix from dependents to requirements.
  std::vector<std::vector<int>> requirements(num_all_reduces,
                                             std::vector<int>());
  for (int i = 0; i < num_all_reduces - 1; ++i) {
    mlir::TF::DTensorAllReduceOp requirement = cross_slice_all_reduces[i];
    for (int j = i + 1; j < num_all_reduces; ++j) {
      mlir::TF::DTensorAllReduceOp dependent = cross_slice_all_reduces[j];
      DCHECK(
          !DependsOn(requirement, dependent));  // guaranteed by program order
      // In the example below, all three DTensorAllReduce ops are independent
      // of each other according to the MLIR value use-def chains considered
      // by DependsOn. However, moving all three to after the WhileRegion and
      // combining them would break the program.
      //
      // %3 = tf.DTensorAllReduce(%1, %2)
      // %4 = tf.WhileRegion(%1) ({
      // ^bb0(%arg):
      //   %5 = tf.ToBool(%arg)
      //   tf.Yield(%5)
      // }, {
      //   %6 = tf.DTensorAllReduce(%1, %2)
      //   tf.Yield(%6)
      // })
      // %7 = tf.DTensorAllReduce(%1, %2)
      //
      // Therefore, in addition to DependsOn, we also check if two
      // DTensorAllReduceOps belong to different blocks. If they do, since they
      // exist in the same ClusterOp, one or both of them must be inside a
      // control flow region block. We treat them as if there is a dependency
      // between them.
      //
      // In the example above, the second DTensorAllReduceOp would "depend on"
      // the first one, and the third on the second. This effectively prevents
      // any two DTensorAllReduce ops from merging together.
      if (requirement->getBlock() != dependent->getBlock() ||
          DependsOn(dependent, requirement)) {
        requirements[j].push_back(i);
      }
    }
  }

  // Traverse the adjacency matrix layer by layer to find combination groups.
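  // Each pass over `requirements` gathers every not-yet-fulfilled all-reduce
  // whose requirements have all been fulfilled; those ops form one combination
  // group. For example, with requirements {1: {0}, 2: {0}}, the first layer is
  // {0} and the second layer is {1, 2}.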
  std::vector<std::vector<mlir::TF::DTensorAllReduceOp>> all_reduce_groups;
  std::set<int> fulfilled;
  while (fulfilled.size() < cross_slice_all_reduces.size()) {
    std::vector<int> fulfilled_this_layer;
    for (int j = 0; j < requirements.size(); ++j) {
      if (fulfilled.count(j) > 0) continue;
      bool requirements_met = true;
      for (int i : requirements[j]) {
        if (fulfilled.count(i) == 0) {
          requirements_met = false;
          break;
        }
      }
      if (requirements_met) {
        fulfilled_this_layer.push_back(j);
      }
    }
    VLOG(4) << "Fulfilled: " << str_util::Join(fulfilled_this_layer, ", ");
    all_reduce_groups.push_back({});
    for (int i : fulfilled_this_layer) {
      fulfilled.insert(i);
      all_reduce_groups.back().push_back(cross_slice_all_reduces[i]);
    }
  }
  VLOG(2) << num_all_reduces << " all-reduce ops in "
          << all_reduce_groups.size() << " groups";

  // Move all-reduces in the same group together and combine them.
  for (auto& all_reduce_group : all_reduce_groups) {
    int num_all_reduces = all_reduce_group.size();
    if (num_all_reduces <= 1) continue;
    mlir::TF::DTensorAllReduceOp final_all_reduce =
        all_reduce_group[num_all_reduces - 1];
    for (int i = num_all_reduces - 2; i >= 0; --i) {
      mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i];
      MoveUsagesAfter(cluster, all_reduce, final_all_reduce);
    }
    for (int i = 0; i < num_all_reduces - 1; ++i) {
      mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i];
      all_reduce->moveBefore(final_all_reduce);
    }
    auto merge_result = MergeAllReduceGroup(all_reduce_group);
    if (merge_result.failed()) return merge_result;
  }

  return mlir::success();
}

// Returns true if the two group assignments are the same value, or are both
// constants with the same shape and elements.
bool same_group_assignments(mlir::Value group_assignment_a,
                            mlir::Value group_assignment_b) {
  if (group_assignment_a == group_assignment_b) {
    return true;
  }
  mlir::DenseIntElementsAttr attr_a;
  if (!matchPattern(group_assignment_a, m_Constant(&attr_a))) {
    return false;
  }
  mlir::DenseIntElementsAttr attr_b;
  if (!matchPattern(group_assignment_b, m_Constant(&attr_b))) {
    return false;
  }
  if (attr_a.getType().getShape() != attr_b.getType().getShape()) {
    return false;
  }
  return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end());
}

// Combines DTensorAllReduce ops of the same element type into as few groups as
// possible. Only ops with the same group assignment can be combined together.
mlir::LogicalResult CombineAllReduceOpsOfSameType(
    mlir::tf_device::ClusterOp cluster,
    const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) {
  // Maintain a list of seen group assignments, sorted by first appearance.
  // Also find and store all-reduces by group assignment. Use the first
  // mlir::Value that contains a certain group assignment to represent all the
  // same group assignments.
  std::vector<mlir::Value> group_assignments;
  llvm::DenseMap<mlir::Value, std::vector<mlir::TF::DTensorAllReduceOp>>
      all_reduces_by_group_assignment;
  for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) {
    mlir::Value group_assignment = all_reduce.group_assignment();
    bool seen = false;
    for (mlir::Value seen_group_assignment : group_assignments) {
      if (same_group_assignments(group_assignment, seen_group_assignment)) {
        group_assignment = seen_group_assignment;
        seen = true;
        break;
      }
    }
    if (!seen) group_assignments.push_back(group_assignment);
    all_reduces_by_group_assignment[group_assignment].push_back(all_reduce);
  }

  // Combine all-reduces of the same group assignment in first-appearance order.
  for (mlir::Value group_assignment : group_assignments) {
    mlir::LogicalResult result =
        CombineAllReduceOpsOfSameTypeAndGroupAssignment(
            cluster, all_reduces_by_group_assignment[group_assignment]);
    if (mlir::failed(result)) return result;
  }

  return mlir::success();
}

struct DTensorAllReduceCombineOptimization
    : public DTensorAllReduceCombineOptimizationBase<
          DTensorAllReduceCombineOptimization> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();
    function.walk([&](mlir::tf_device::ClusterOp cluster) {
      // Maintain a list of seen element types, sorted by first appearance.
      // Also find and store all-reduces by element type.
      std::vector<mlir::Type> elem_types;
      llvm::DenseMap<mlir::Type, std::vector<mlir::TF::DTensorAllReduceOp>>
          all_reduces_by_elem_type;
      cluster.GetBody().walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
        mlir::Type elem_type = all_reduce.getType().getElementType();
        if (std::find(elem_types.begin(), elem_types.end(), elem_type) ==
            elem_types.end()) {
          elem_types.push_back(elem_type);
        }
        all_reduces_by_elem_type[elem_type].push_back(all_reduce);
      });

      // Combine all-reduces of the same element type in first-appearance order.
      for (mlir::Type elem_type : elem_types) {
        if (mlir::failed(CombineAllReduceOpsOfSameType(
                cluster, all_reduces_by_elem_type[elem_type]))) {
          return signalPassFailure();
        }
      }
    });
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceCombineOptimization() {
  return std::make_unique<DTensorAllReduceCombineOptimization>();
}

}  // namespace dtensor
}  // namespace tensorflow