/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorMergeCluster pass. "
    "All clusters must have a specified mesh.";

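// Prefix of the key used to pair the DTensorSend/DTensorRecv ops that transfer
// control flow predicates across meshes.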
constexpr char kSendRecvKeyPrefix[] = "SendRecvKeyForControlflow_";

// Extracts mesh from `cluster`.
mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
                                           Mesh* mesh_output) {
  auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
  if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);

  const absl::optional<Mesh>& mesh_or_null = *mesh_or_status;
  if (!mesh_or_null.has_value())
    return cluster.emitOpError(kMissingMeshErrorMsg);

  *mesh_output = mesh_or_null.value();
  return mlir::success();
}

// Returns all tf_device.ClusterOps nested inside `op`.
llvm::SmallVector<mlir::tf_device::ClusterOp, 4> FindAllDeviceClusters(
    mlir::Operation* op) {
  llvm::SmallVector<mlir::tf_device::ClusterOp, 4> nested_clusters;
  op->walk([&](mlir::tf_device::ClusterOp nested_cluster) {
    nested_clusters.emplace_back(nested_cluster);
  });
  return nested_clusters;
}

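// Merges two sets of (index, layout) metadata attributes into `merged_indices`
// and `merged_layout`, emitting an error on `op` if the same index is mapped
// to conflicting layouts.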
mlir::LogicalResult MergeAttributes(
    mlir::Operation* op, mlir::DenseIntElementsAttr indices_attr,
    mlir::ArrayAttr layout_attr, mlir::DenseIntElementsAttr indices_attr2,
    mlir::ArrayAttr layout_attr2, llvm::SmallVector<int, 4>* merged_indices,
    llvm::SmallVector<mlir::Attribute, 4>* merged_layout) {
  llvm::SmallDenseMap<llvm::APInt, mlir::Attribute> attr_map;
  attr_map.reserve(indices_attr.size() + indices_attr2.size());
  for (const auto& data : llvm::zip(indices_attr, layout_attr))
    attr_map.try_emplace(std::get<0>(data), std::get<1>(data));

  for (const auto& data : llvm::zip(indices_attr2, layout_attr2)) {
    const auto& index = std::get<0>(data);
    const auto& layout = std::get<1>(data);
    auto result = attr_map.try_emplace(index, layout);

    if (!result.second && layout != result.first->getSecond()) {
      return op->emitOpError(
          "Found conflicting metadata attributes while merging clusters");
    }
  }

  merged_indices->reserve(attr_map.size());
  merged_layout->reserve(attr_map.size());
  for (const auto& it : attr_map) {
    merged_indices->emplace_back(it.first.getSExtValue());
    merged_layout->emplace_back(it.second);
  }
  return mlir::success();
}

// Merges metadata attributes from `src_cluster` to `target_cluster`. If a
// metadata attribute exists on both clusters, merge the attributes and verify
// that there are no conflicting attributes.
mlir::LogicalResult MergeClusterMetadata(
    mlir::tf_device::ClusterOp src_cluster,
    mlir::tf_device::ClusterOp target_cluster) {
  if (mlir::failed(ValidateMetadataAttributes(src_cluster)) ||
      mlir::failed(ValidateMetadataAttributes(target_cluster)))
    return mlir::failure();

  mlir::OpBuilder builder(target_cluster);

  // Extract resource metadata from src/target clusters.
  auto src_resource_handle_indices_metadata =
      src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kNewResourceLayoutIndices);
  auto src_inferred_resource_handle_layouts_metadata =
      src_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);

  auto target_resource_handle_indices_metadata =
      target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kNewResourceLayoutIndices);
  auto target_inferred_resource_handle_layouts_metadata =
      target_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
  const bool should_merge_resource_metadata =
      (src_inferred_resource_handle_layouts_metadata &&
       src_resource_handle_indices_metadata &&
       target_inferred_resource_handle_layouts_metadata &&
       target_resource_handle_indices_metadata);
  // If only the source cluster has metadata, then simply copy the metadata to
  // the target cluster.
  if (src_inferred_resource_handle_layouts_metadata &&
      !target_inferred_resource_handle_layouts_metadata) {
    target_cluster->setAttr(kNewResourceLayoutIndices,
                            src_resource_handle_indices_metadata);
    target_cluster->setAttr(kNewResourceArgLayouts,
                            src_inferred_resource_handle_layouts_metadata);
  } else if (should_merge_resource_metadata) {
    // If both the src and target clusters have metadata, merge the metadata
    // and check that there are no conflicts.
    llvm::SmallVector<int, 4> merged_resource_indices;
    llvm::SmallVector<mlir::Attribute, 4> merged_resource_layouts;
    if (mlir::failed(MergeAttributes(
            src_cluster, src_resource_handle_indices_metadata,
            src_inferred_resource_handle_layouts_metadata,
            target_resource_handle_indices_metadata,
            target_inferred_resource_handle_layouts_metadata,
            &merged_resource_indices, &merged_resource_layouts)))
      return mlir::failure();

    target_cluster->setAttr(
        kNewResourceArgLayouts,
        builder.getArrayAttr(
            llvm::ArrayRef<mlir::Attribute>(merged_resource_layouts)));

    target_cluster->setAttr(
        kNewResourceLayoutIndices,
        builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_resource_indices)));
  }

  // Extract shape metadata from src/target clusters.
  auto src_shape_layouts =
      src_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
  auto src_shape_op_indices =
      src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kShapeOpInputLayoutIndices);
  auto target_shape_layouts =
      target_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
  auto target_shape_op_indices =
      target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
          kShapeOpInputLayoutIndices);

  const bool should_merge_shape_metadata =
      (src_shape_layouts && src_shape_op_indices && target_shape_layouts &&
       target_shape_op_indices);

  // If only the src cluster has shape metadata, copy the shape metadata to the
  // target cluster.
  if (src_shape_layouts && !target_shape_layouts) {
    target_cluster->setAttr(kShapeOpInputLayoutIndices, src_shape_op_indices);
    target_cluster->setAttr(kShapeOpInputLayout, src_shape_layouts);
  } else if (should_merge_shape_metadata) {
    // If both src/target clusters have shape metadata, merge the shape
    // metadata and set the merged metadata on the target cluster.
    llvm::SmallVector<int, 4> merged_shape_indices;
    llvm::SmallVector<mlir::Attribute, 4> merged_shape_layouts;
    if (mlir::failed(MergeAttributes(
            src_cluster, src_shape_op_indices, src_shape_layouts,
            target_shape_op_indices, target_shape_layouts,
            &merged_shape_indices, &merged_shape_layouts)))
      return mlir::failure();

    target_cluster->setAttr(
        kShapeOpInputLayout,
        builder.getArrayAttr(
            llvm::ArrayRef<mlir::Attribute>(merged_shape_layouts)));

    target_cluster->setAttr(
        kShapeOpInputLayoutIndices,
        builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_shape_indices)));
  }

  return mlir::success();
}

// Removes a tf_device.Cluster op if it is nested inside another cluster and
// has the same mesh specification as the parent cluster.
mlir::LogicalResult InlineNestedDeviceClusters(mlir::ModuleOp module) {
  auto clusters = FindAllDeviceClusters(module);
  for (mlir::tf_device::ClusterOp cluster : clusters) {
    auto parent_cluster =
        cluster->getParentOfType<mlir::tf_device::ClusterOp>();
    if (!parent_cluster) continue;

    Mesh cluster_mesh;
    if (mlir::failed(ExtractMeshFromCluster(cluster, &cluster_mesh)))
      return mlir::failure();

    Mesh parent_cluster_mesh;
    if (mlir::failed(
            ExtractMeshFromCluster(parent_cluster, &parent_cluster_mesh)))
      return mlir::failure();

    if (parent_cluster_mesh != cluster_mesh) continue;

    // Found a tf_device.cluster that has the same mesh specification as the
    // enclosing parent cluster. Remove the child cluster and move all of its
    // ops to the parent cluster instead.
    for (auto it : llvm::zip(cluster.GetBody().getTerminator()->getOperands(),
                             cluster.results())) {
      mlir::Value new_value = std::get<0>(it);
      mlir::Value value_to_replace = std::get<1>(it);
      value_to_replace.replaceAllUsesWith(new_value);
    }
    for (mlir::Operation& op :
         llvm::make_early_inc_range(cluster.GetBody().without_terminator())) {
      op.moveBefore(cluster);
    }

    if (mlir::failed(MergeClusterMetadata(cluster, parent_cluster)))
      return mlir::failure();

    cluster.erase();
  }
  return mlir::success();
}

// Clones the IfRegionOp `if_region` and its attributes, creating then/else
// regions that each contain an empty block with a single yield op.
void CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, const Mesh& mesh,
                               mlir::OpBuilder& builder, int* num_send_recvs,
                               mlir::MLIRContext* context,
                               mlir::TF::IfRegionOp* cloned_if_region_op) {
  // Create a DTensorSend op just before the tf.If op, before creating the new
  // cluster. The DTensorSend op sends the predicate to the `mesh` cluster with
  // a replicated layout.
  mlir::TensorType predicate_tensor_type =
      if_region.cond().getType().cast<mlir::TensorType>();
  const std::string send_recv_key =
      absl::StrCat(kSendRecvKeyPrefix, *num_send_recvs);
  *num_send_recvs += 1;

  const Layout target_layout = Layout::ReplicatedOnMesh(mesh, 0);
  builder.create<mlir::TF::DTensorSend>(
      if_region.getLoc(), if_region.cond(),
      builder.getStringAttr(send_recv_key),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Create new cluster op that contains cloned if operation.
  auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{});
  new_cluster.body().push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&new_cluster.GetBody());
  auto return_op = builder.create<mlir::tf_device::ReturnOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Value, 4>{});

  // Add a DTensorRecv op inside the new cluster that receives the predicate.
  builder.setInsertionPoint(return_op);
  auto recv_op = builder.create<mlir::TF::DTensorRecv>(
      if_region.getLoc(), predicate_tensor_type,
      builder.getStringAttr(send_recv_key),
      mlir::TF::ShapeAttr::get(context, predicate_tensor_type),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Clone tf.IfRegion op inside newly created cluster and make sure
  // that the predicate tensor is from DTensorRecv op created above.
  auto host_side_if = builder.create<mlir::TF::IfRegionOp>(
      if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{}, recv_op.output(),
      if_region.is_stateless(),
      GetUniqueControlflowFnName("cloned_if_then", builder),
      GetUniqueControlflowFnName("cloned_if_else", builder));
  *cloned_if_region_op = host_side_if;

  // Create empty then branch region.
  auto& then_branch = host_side_if.then_branch();
  then_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});

  // Create empty else branch region.
  auto& else_branch = host_side_if.else_branch();
  else_branch.push_back(new mlir::Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
                                    /*operands=*/llvm::ArrayRef<mlir::Value>{});
  new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));
}

// Verifies that send/recv ops are used for the inputs and outputs of the
// cluster. That is, the cluster should not have any input/output edges.
mlir::LogicalResult VerifyClusterInputOutput(
    mlir::tf_device::ClusterOp cluster) {
  if (cluster.getNumResults() > 0)
    return cluster->emitOpError(
        "found nested tf_device.Cluster op with outputs. Nested cluster must "
        "use send/recv instead.");

  mlir::LogicalResult result = mlir::success();
  mlir::visitUsedValuesDefinedAbove(
      cluster.body(), cluster.body(), [&](mlir::OpOperand* input) {
        if (!input->get().isa<mlir::BlockArgument>()) {
          result = cluster.emitOpError(
              "found nested tf_device.Cluster op with inputs. Nested cluster "
              "must use send/recv instead.");
          return;
        }
      });
  return result;
}

// Returns whether `cluster` is inside the then branch of `if_op`.
bool IsInsideIfThenBranch(mlir::TF::IfRegionOp if_op,
                          mlir::tf_device::ClusterOp cluster) {
  assert(if_op->isProperAncestor(cluster));
  return if_op.then_branch().isAncestor(cluster->getParentRegion());
}

// Decomposes multi-mesh computation nested inside tf_if operations. See
// comments for `DecomposeControlflow()` function for details.
mlir::LogicalResult DecomposeIf(mlir::TF::IfRegionOp if_op,
                                mlir::MLIRContext* context,
                                int* num_control_flow_send_recvs) {
  auto nested_clusters = FindAllDeviceClusters(if_op);
  if (nested_clusters.empty()) return mlir::success();

  for (mlir::tf_device::ClusterOp nested_cluster : nested_clusters) {
    if (mlir::failed(VerifyClusterInputOutput(nested_cluster)))
      return mlir::failure();

    Mesh nested_mesh;
    if (mlir::failed(ExtractMeshFromCluster(nested_cluster, &nested_mesh)))
      return mlir::failure();

    mlir::OpBuilder builder(if_op);
    mlir::TF::IfRegionOp cloned_if;
    CloneEmptyIfWithPredicate(if_op, nested_mesh, builder,
                              num_control_flow_send_recvs, context, &cloned_if);

    // Find nested clusters in the then/else branch of the original `if_op` and
    // move all inner ops inside the nested cluster to `cloned_if` in the
    // corresponding branch.
    if (IsInsideIfThenBranch(if_op, nested_cluster)) {
      mlir::Operation* then_branch_terminator =
          cloned_if.then_branch().begin()->getTerminator();
      auto& nested_cluster_operations =
          nested_cluster.GetBody().getOperations();
      cloned_if.then_branch().begin()->getOperations().splice(
          then_branch_terminator->getIterator(), nested_cluster_operations,
          nested_cluster_operations.begin(),
          std::prev(nested_cluster_operations.end()));
    } else {
      mlir::Operation* else_branch_terminator =
          cloned_if.else_branch().begin()->getTerminator();
      auto& nested_cluster_operations =
          nested_cluster.GetBody().getOperations();
      cloned_if.else_branch().begin()->getOperations().splice(
          else_branch_terminator->getIterator(), nested_cluster_operations,
          nested_cluster_operations.begin(),
          std::prev(nested_cluster_operations.end()));
    }
    nested_cluster.erase();
  }
  return mlir::success();
}

// Decomposes controlflows with nested mesh computations. When multi-mesh
// computation exists inside control flow operations like tf.If, then
// the control flow operations should be replicated to ensure correct execution
// semantics.
// For example:
//
// "tf_device.cluster"() ( {
//    %1 = "tf.G"() : () -> (tensor<i1>)
//    "tf.IfRegion"(%1) ({
//      "tf_device.cluster"() ( {
//        "tf.D"() {} : () -> ()
//        tf_device.return
//      }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> ()
//
//      "tf.Yield"() : () -> ()
//    }, {
//    }) {is_stateless = false} : (tensor<i1>) -> ()
//    tf_device.return
//  }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> ()
//
// Above computation includes TPU device computation that exists inside
// tf.If op in CPU mesh. In this case, tf.If op should be replicated to TPU
// device computation so that `tf.D` op is executed in sync with CPU side
// computation. After transformation in this function, above IR is changed to:
//
// "tf_device.cluster"() ( {
//      %1 = "tf.DTensorRecv"() : () -> tensor<i1>
//      "tf.IfRegion"(%1) ( {
//        "tf.D"() : () -> ()
//        "tf.Yield"() : () -> ()
//      },  {
//        "tf.Yield"() : () -> ()
//      }) {is_stateless = false} : (tensor<i1>) -> ()
//      tf_device.return
//    }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> ()
//
// "tf_device.cluster"() ( {
//   %1 = "tf.G"() : () -> tensor<i1>
//   "tf.DTensorSend"(%1) : (tensor<i1>) -> ()
//   "tf.IfRegion"(%1) ( {
//     "tf.Yield"() : () -> ()
//    },  {
//     "tf.Yield"() : () -> ()
//    }) {is_stateless = false} : (tensor<i1>) -> ()
//    tf_device.return
// }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> ()
//
// Note that:
//  1) Control flow is replicated.
//  2) DTensorSend/Recv ops are added to transfer predicate tensors for
//     control flow operations.
mlir::LogicalResult DecomposeControlflow(mlir::MLIRContext* context,
                                         int* num_control_flow_send_recvs,
                                         mlir::ModuleOp module) {
  llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
  // Identify all clusters in topological order.
  module.walk([&](mlir::tf_device::ClusterOp cluster) {
    clusters.emplace_back(cluster);
  });

  for (mlir::tf_device::ClusterOp cluster : clusters) {
    mlir::WalkResult walk_result = cluster->walk([&](mlir::Operation* op) {
      if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) {
        if (mlir::failed(
                DecomposeIf(if_op, context, num_control_flow_send_recvs)))
          return mlir::WalkResult::interrupt();
      }
      return mlir::WalkResult::advance();
    });
    if (walk_result.wasInterrupted()) return mlir::failure();
  }

  return mlir::success();
}

// Merges multiple tf_device.clusters with the same mesh specification into a
// single mesh cluster.
mlir::LogicalResult MergeClusters(mlir::ModuleOp module) {
  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");

  // Create a global cluster for each mesh in the entire computation.
  auto clusters = FindAllDeviceClusters(main_func);
  mlir::Block& func_block = *main_func.getBody().begin();
  mlir::OpBuilder builder(&func_block.front());
  std::map<Mesh, llvm::SmallVector<mlir::tf_device::ClusterOp, 4>> cluster_map;
  std::vector<Mesh> meshes;
  for (mlir::tf_device::ClusterOp cluster : clusters) {
    Mesh mesh;
    if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh)))
      return mlir::failure();

    if (cluster_map.find(mesh) != cluster_map.end()) {
      cluster_map[mesh].emplace_back(cluster);
    } else {
      cluster_map[mesh] =
          llvm::SmallVector<mlir::tf_device::ClusterOp, 4>{cluster};
      meshes.push_back(std::move(mesh));
    }
  }

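  // Order meshes by device type (falling back to the full mesh comparison)
  // before creating the merged clusters.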
  // Reevaluate if this sort is necessary after b/186804270 is closed.
  std::sort(meshes.begin(), meshes.end(), [](const Mesh& a, const Mesh& b) {
    if (a.device_type() != b.device_type()) {
      return a.device_type() < b.device_type();
    }
    return a < b;
  });
  for (const Mesh& mesh : meshes) {
    const auto& mesh_cluster_list = cluster_map[mesh];
    llvm::SmallVector<mlir::Value, 4> merged_cluster_outputs;
    llvm::SmallVector<mlir::Value, 4> merged_return_values;
    llvm::SmallVector<mlir::Type, 4> merged_return_types;

    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      merged_cluster_outputs.insert(merged_cluster_outputs.end(),
                                    cluster.results().begin(),
                                    cluster.results().end());

      auto return_values = cluster.GetBody().getTerminator()->getOperands();
      merged_return_values.insert(merged_return_values.end(),
                                  return_values.begin(), return_values.end());

      auto return_type = cluster->getResultTypes();
      merged_return_types.insert(merged_return_types.end(), return_type.begin(),
                                 return_type.end());
    }

    // Create a single cluster op that contains the merged computations for
    // `mesh`.
    builder.setInsertionPoint(&func_block.front());
    auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
        module.getLoc(), merged_return_types);
    new_cluster.body().push_back(new mlir::Block);
    new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));

    // Move all ops inside the clusters for this mesh to `new_cluster`.
    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      mlir::Block& cluster_body = cluster.GetBody();
      for (mlir::Operation& op_to_move :
           llvm::make_early_inc_range(cluster_body.without_terminator())) {
        for (mlir::OpOperand& use : op_to_move.getUses()) {
          auto return_op =
              llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner());
          if (!return_op) continue;

          mlir::Value output = cluster.getResult(use.getOperandNumber());
          output.replaceUsesWithIf(use.get(), [](mlir::OpOperand& operand) {
            return operand.getOwner()
                       ->getParentOfType<mlir::tf_device::ClusterOp>() !=
                   nullptr;
          });
        }
        op_to_move.moveBefore(new_cluster.getBody(),
                              new_cluster.getBody()->end());
      }
    }

    builder.setInsertionPointToEnd(&new_cluster.GetBody());
    builder.create<mlir::tf_device::ReturnOp>(new_cluster.getLoc(),
                                              merged_return_values);

    // Replace return value usages.
    for (auto it :
         llvm::zip(merged_cluster_outputs, new_cluster.getResults())) {
      mlir::Value value_to_replace = std::get<0>(it);
      mlir::Value new_result_value = std::get<1>(it);
      value_to_replace.replaceAllUsesWith(new_result_value);
    }

    // Erase clusters in cluster_map now that all ops are moved.
    for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
      if (mlir::failed(MergeClusterMetadata(cluster, new_cluster)))
        return mlir::failure();

      cluster.erase();
    }
  }

  return mlir::success();
}

// Pass that merges multiple tf_device.Cluster ops for multi-mesh computation
// into a single cluster. After this pass, exactly one tf_device.Cluster op
// exists for each device mesh.
struct DTensorMergeClusters
    : public DTensorMergeClustersBase<DTensorMergeClusters> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder op_builder(&context);
    auto module = getOperation();
    if (mlir::failed(InlineNestedDeviceClusters(module)))
      return signalPassFailure();

    int num_controlflow_send_recv = 0;
    if (mlir::failed(
            DecomposeControlflow(&context, &num_controlflow_send_recv, module)))
      return signalPassFailure();

    if (mlir::failed(MergeClusters(module))) return signalPassFailure();

    llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
    module.walk([&](mlir::tf_device::ClusterOp cluster) {
      clusters.emplace_back(cluster);
    });

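    // Remove cluster results that are no longer used after merging.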
    for (mlir::tf_device::ClusterOp cluster : clusters) {
      RemoveUnusedClusterResults(cluster);
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorMergeClustersPass() {
  return std::make_unique<DTensorMergeClusters>();
}

}  // namespace dtensor
}  // namespace tensorflow