1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <string>
18 #include <utility>
19
20 #include "absl/strings/str_cat.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Builders.h" // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
30 #include "mlir/IR/Diagnostics.h" // from @llvm-project
31 #include "mlir/IR/Operation.h" // from @llvm-project
32 #include "mlir/IR/Types.h" // from @llvm-project
33 #include "mlir/IR/Value.h" // from @llvm-project
34 #include "mlir/IR/Visitors.h" // from @llvm-project
35 #include "mlir/Pass/Pass.h" // from @llvm-project
36 #include "mlir/Pass/PassManager.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38 #include "mlir/Transforms/Passes.h" // from @llvm-project
39 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
44 #include "tensorflow/dtensor/cc/constants.h"
45 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
46 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
47 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
48 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
49 #include "tensorflow/dtensor/mlir/layout_parsing.h"
50 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
51
52 namespace tensorflow {
53 namespace dtensor {
54 namespace {
55
// Diagnostic text used when a tf_device.cluster op has no extractable mesh.
constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorMergeCluster pass. All clusters must "
    "have specified mesh.";

// Prefix of the keys shared by the DTensorSend/DTensorRecv pairs that this
// file creates to forward control flow predicates; a running counter is
// appended to it to form each key.
constexpr char kSendRecvKeyPrefix[] = "SendRecvKeyForControlflow_";
61
62 // Extracts mesh from `cluster`.
ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,Mesh * mesh_output)63 mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
64 Mesh* mesh_output) {
65 auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
66 if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);
67
68 const absl::optional<Mesh>& mesh_or_null = *mesh_or_status;
69 if (!mesh_or_null.has_value())
70 return cluster.emitOpError(kMissingMeshErrorMsg);
71
72 *mesh_output = mesh_or_null.value();
73 return mlir::success();
74 }
75
76 // Returns all tf_device.ClusterOps nested inside `op`.
FindAllDeviceClusters(mlir::Operation * op)77 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> FindAllDeviceClusters(
78 mlir::Operation* op) {
79 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> nested_clusters;
80 op->walk([&](mlir::tf_device::ClusterOp nested_cluster) {
81 nested_clusters.emplace_back(nested_cluster);
82 });
83 return nested_clusters;
84 }
85
MergeAttributes(mlir::Operation * op,mlir::DenseIntElementsAttr indices_attr,mlir::ArrayAttr layout_attr,mlir::DenseIntElementsAttr indices_attr2,mlir::ArrayAttr layout_attr2,llvm::SmallVector<int,4> * merged_indices,llvm::SmallVector<mlir::Attribute,4> * merged_layout)86 mlir::LogicalResult MergeAttributes(
87 mlir::Operation* op, mlir::DenseIntElementsAttr indices_attr,
88 mlir::ArrayAttr layout_attr, mlir::DenseIntElementsAttr indices_attr2,
89 mlir::ArrayAttr layout_attr2, llvm::SmallVector<int, 4>* merged_indices,
90 llvm::SmallVector<mlir::Attribute, 4>* merged_layout) {
91 llvm::SmallDenseMap<llvm::APInt, mlir::Attribute> attr_map;
92 attr_map.reserve(indices_attr.size() + indices_attr2.size());
93 for (const auto& data : llvm::zip(indices_attr, layout_attr))
94 attr_map.try_emplace(std::get<0>(data), std::get<1>(data));
95
96 for (const auto& data : llvm::zip(indices_attr2, layout_attr2)) {
97 const auto& index = std::get<0>(data);
98 const auto& layout = std::get<1>(data);
99 auto result = attr_map.try_emplace(index, layout);
100
101 if (!result.second && layout != result.first->getSecond()) {
102 return op->emitOpError(
103 "Found conflicting metadata attributes while merging clusters");
104 }
105 }
106
107 merged_indices->reserve(attr_map.size());
108 merged_layout->reserve(attr_map.size());
109 for (const auto& it : attr_map) {
110 merged_indices->emplace_back(it.first.getSExtValue());
111 merged_layout->emplace_back(it.second);
112 }
113 return mlir::success();
114 }
115
116 // Merges metadata attribute from `src_cluster` to `target_cluster`. If metadata
117 // attribute exists for both clusters, merge the attributes and verify that
118 // there are no conflicing attributes.
MergeClusterMetadata(mlir::tf_device::ClusterOp src_cluster,mlir::tf_device::ClusterOp target_cluster)119 mlir::LogicalResult MergeClusterMetadata(
120 mlir::tf_device::ClusterOp src_cluster,
121 mlir::tf_device::ClusterOp target_cluster) {
122 if (mlir::failed(ValidateMetadataAttributes(src_cluster)) ||
123 mlir::failed(ValidateMetadataAttributes(target_cluster)))
124 return mlir::failure();
125
126 mlir::OpBuilder builder(target_cluster);
127
128 // Extract resource metadata from src/target clusters.
129 auto src_resource_handle_indices_metadata =
130 src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
131 kNewResourceLayoutIndices);
132 auto src_inferred_resource_handle_layouts_metadata =
133 src_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
134
135 auto target_resource_handle_indices_metadata =
136 target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
137 kNewResourceLayoutIndices);
138 auto target_inferred_resource_handle_layouts_metadata =
139 target_cluster->getAttrOfType<mlir::ArrayAttr>(kNewResourceArgLayouts);
140 const bool should_merge_resource_metadata =
141 (src_inferred_resource_handle_layouts_metadata &&
142 src_resource_handle_indices_metadata &&
143 target_inferred_resource_handle_layouts_metadata &&
144 target_resource_handle_indices_metadata);
145 // If only source cluster has metadata, then simply copy the metadata to
146 // target cluster.
147 if (src_inferred_resource_handle_layouts_metadata &&
148 !target_inferred_resource_handle_layouts_metadata) {
149 target_cluster->setAttr(kNewResourceLayoutIndices,
150 src_resource_handle_indices_metadata);
151 target_cluster->setAttr(kNewResourceArgLayouts,
152 src_inferred_resource_handle_layouts_metadata);
153 } else if (should_merge_resource_metadata) {
154 // If both src cluster and target cluster has metadata, merge the metadata
155 // and check if there are no conflicts.
156 llvm::SmallVector<int, 4> merged_resource_indices;
157 llvm::SmallVector<mlir::Attribute, 4> merged_resource_layouts;
158 if (mlir::failed(MergeAttributes(
159 src_cluster, src_resource_handle_indices_metadata,
160 src_inferred_resource_handle_layouts_metadata,
161 target_resource_handle_indices_metadata,
162 target_inferred_resource_handle_layouts_metadata,
163 &merged_resource_indices, &merged_resource_layouts)))
164 return mlir::failure();
165
166 target_cluster->setAttr(
167 kNewResourceArgLayouts,
168 builder.getArrayAttr(
169 llvm::ArrayRef<mlir::Attribute>(merged_resource_layouts)));
170
171 target_cluster->setAttr(
172 kNewResourceLayoutIndices,
173 builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_resource_indices)));
174 }
175
176 // Extract shape metadata from src/target clusters.
177 auto src_shape_layouts =
178 src_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
179 auto src_shape_op_indices =
180 src_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
181 kShapeOpInputLayoutIndices);
182 auto target_shape_layouts =
183 target_cluster->getAttrOfType<mlir::ArrayAttr>(kShapeOpInputLayout);
184 auto target_shape_op_indices =
185 target_cluster->getAttrOfType<mlir::DenseIntElementsAttr>(
186 kShapeOpInputLayoutIndices);
187
188 const bool should_merge_shape_metadata =
189 (src_shape_layouts && src_shape_op_indices && target_shape_layouts &&
190 target_shape_op_indices);
191
192 // If only src cluster has shape metadata, copy shape metadata to target
193 // cluster.
194 if (src_shape_layouts && !target_shape_layouts) {
195 target_cluster->setAttr(kShapeOpInputLayoutIndices, src_shape_op_indices);
196 target_cluster->setAttr(kShapeOpInputLayout, src_shape_layouts);
197 } else if (should_merge_shape_metadata) {
198 // If both src/target clusters have shape metadata, merge the shape metadata
199 // and set the merged metadata to target cluster.
200 llvm::SmallVector<int, 4> merged_shape_indices;
201 llvm::SmallVector<mlir::Attribute, 4> merged_shape_layouts;
202 if (mlir::failed(MergeAttributes(
203 src_cluster, src_shape_op_indices, src_shape_layouts,
204 target_shape_op_indices, target_shape_layouts,
205 &merged_shape_indices, &merged_shape_layouts)))
206 return mlir::failure();
207
208 target_cluster->setAttr(
209 kShapeOpInputLayout,
210 builder.getArrayAttr(
211 llvm::ArrayRef<mlir::Attribute>(merged_shape_layouts)));
212
213 target_cluster->setAttr(
214 kShapeOpInputLayoutIndices,
215 builder.getI32VectorAttr(llvm::ArrayRef<int>(merged_shape_indices)));
216 }
217
218 return mlir::success();
219 }
220
221 // Removes tf_device.Cluster ops if tf_device.Cluster is nested inside another
222 // cluster and it has same mesh specification as parent cluster.
InlineNestedDeviceClusters(mlir::ModuleOp module)223 mlir::LogicalResult InlineNestedDeviceClusters(mlir::ModuleOp module) {
224 auto clusters = FindAllDeviceClusters(module);
225 for (mlir::tf_device::ClusterOp cluster : clusters) {
226 auto parent_cluster =
227 cluster->getParentOfType<mlir::tf_device::ClusterOp>();
228 if (!parent_cluster) continue;
229
230 Mesh cluster_mesh;
231 if (mlir::failed(ExtractMeshFromCluster(cluster, &cluster_mesh)))
232 return mlir::failure();
233
234 Mesh parent_cluster_mesh;
235 if (mlir::failed(
236 ExtractMeshFromCluster(parent_cluster, &parent_cluster_mesh)))
237 return mlir::failure();
238
239 if (parent_cluster_mesh != cluster_mesh) continue;
240
241 // Found a tf_device.cluster that has same mesh specification as parent
242 // enclosing cluster. Remove the child cluster and move all ops to parent
243 // cluster instead.
244 for (auto it : llvm::zip(cluster.GetBody().getTerminator()->getOperands(),
245 cluster.results())) {
246 mlir::Value new_value = std::get<0>(it);
247 mlir::Value value_to_replace = std::get<1>(it);
248 value_to_replace.replaceAllUsesWith(new_value);
249 }
250 for (mlir::Operation& op :
251 llvm::make_early_inc_range(cluster.GetBody().without_terminator())) {
252 op.moveBefore(cluster);
253 }
254
255 if (mlir::failed(MergeClusterMetadata(cluster, parent_cluster)))
256 return mlir::failure();
257
258 cluster.erase();
259 }
260 return mlir::success();
261 }
262
263 // Clones an IfRegionOp 'if_region' and attributes and creates then/else regions
264 // with yield op and an empty block.
CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region,const Mesh & mesh,mlir::OpBuilder & builder,int * num_send_recvs,mlir::MLIRContext * context,mlir::TF::IfRegionOp * cloned_if_region_op)265 void CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, const Mesh& mesh,
266 mlir::OpBuilder& builder, int* num_send_recvs,
267 mlir::MLIRContext* context,
268 mlir::TF::IfRegionOp* cloned_if_region_op) {
269 // Create DTensorSend just before tf.If op before creating new cluster. The
270 // DTensorSend op sends the predicate to `mesh` cluster with replicated
271 // layout.
272 mlir::TensorType predicate_tensor_type =
273 if_region.cond().getType().cast<mlir::TensorType>();
274 const std::string send_recv_key =
275 absl::StrCat(kSendRecvKeyPrefix, *num_send_recvs);
276 *num_send_recvs += 1;
277
278 const Layout target_layout = Layout::ReplicatedOnMesh(mesh, 0);
279 builder.create<mlir::TF::DTensorSend>(
280 if_region.getLoc(), if_region.cond(),
281 builder.getStringAttr(send_recv_key),
282 mlir::dtensor::LayoutAttr::get(context, target_layout));
283
284 // Create new cluster op that contains cloned if operation.
285 auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
286 if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{});
287 new_cluster.body().push_back(new mlir::Block);
288 builder.setInsertionPointToEnd(&new_cluster.GetBody());
289 auto return_op = builder.create<mlir::tf_device::ReturnOp>(
290 if_region.getLoc(), llvm::SmallVector<mlir::Value, 4>{});
291
292 // Add DTensorRecv op inside new cluster that receives the cluster.
293 builder.setInsertionPoint(return_op);
294 auto recv_op = builder.create<mlir::TF::DTensorRecv>(
295 if_region.getLoc(), predicate_tensor_type,
296 builder.getStringAttr(send_recv_key),
297 mlir::TF::ShapeAttr::get(context, predicate_tensor_type),
298 mlir::dtensor::LayoutAttr::get(context, target_layout));
299
300 // Clone tf.IfRegion op inside newly created cluster and make sure
301 // that the predicate tensor is from DTensorRecv op created above.
302 auto host_side_if = builder.create<mlir::TF::IfRegionOp>(
303 if_region.getLoc(), llvm::SmallVector<mlir::Type, 4>{}, recv_op.output(),
304 if_region.is_stateless(),
305 GetUniqueControlflowFnName("cloned_if_then", builder),
306 GetUniqueControlflowFnName("cloned_if_else", builder));
307 *cloned_if_region_op = host_side_if;
308
309 // Create empty then branch region.
310 auto& then_branch = host_side_if.then_branch();
311 then_branch.push_back(new mlir::Block);
312 builder.setInsertionPointToEnd(&then_branch.front());
313 builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
314 /*operands=*/llvm::ArrayRef<mlir::Value>{});
315
316 // Create empty else branch region.
317 auto& else_branch = host_side_if.else_branch();
318 else_branch.push_back(new mlir::Block);
319 builder.setInsertionPointToEnd(&else_branch.front());
320 builder.create<mlir::TF::YieldOp>(if_region.getLoc(),
321 /*operands=*/llvm::ArrayRef<mlir::Value>{});
322 new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));
323 }
324
325 // Verifies that send/recv ops are used for input output of cluster. That is,
326 // cluster should not have any input/output edges.
VerifyClusterInputOutput(mlir::tf_device::ClusterOp cluster)327 mlir::LogicalResult VerifyClusterInputOutput(
328 mlir::tf_device::ClusterOp cluster) {
329 if (cluster.getNumResults() > 0)
330 return cluster->emitOpError(
331 "found nested tf_device.Cluster op with outputs. Nested cluster must "
332 "use send/recv instead.");
333
334 mlir::LogicalResult result = mlir::success();
335 mlir::visitUsedValuesDefinedAbove(
336 cluster.body(), cluster.body(), [&](mlir::OpOperand* input) {
337 if (!input->get().isa<mlir::BlockArgument>()) {
338 result = cluster.emitOpError(
339 "found nested tf_device.Cluster op with inputs. Nested cluster "
340 "must use send/recv instead.");
341 return;
342 }
343 });
344 return result;
345 }
346
347 // Returns whether `cluster` is inside then branch of `if_op`.
IsInsideIfThenBranch(mlir::TF::IfRegionOp if_op,mlir::tf_device::ClusterOp cluster)348 bool IsInsideIfThenBranch(mlir::TF::IfRegionOp if_op,
349 mlir::tf_device::ClusterOp cluster) {
350 assert(if_op->isProperAncestor(cluster));
351 return if_op.then_branch().isAncestor(cluster->getParentRegion());
352 }
353
354 // Decomposes multi-mesh computation nested inside tf_if operations. See
355 // comments for `DecomposeControlflow()` function for details.
DecomposeIf(mlir::TF::IfRegionOp if_op,mlir::MLIRContext * context,int * num_control_flow_send_recvs)356 mlir::LogicalResult DecomposeIf(mlir::TF::IfRegionOp if_op,
357 mlir::MLIRContext* context,
358 int* num_control_flow_send_recvs) {
359 auto nested_clusters = FindAllDeviceClusters(if_op);
360 if (nested_clusters.empty()) return mlir::success();
361
362 for (mlir::tf_device::ClusterOp nested_cluster : nested_clusters) {
363 if (mlir::failed(VerifyClusterInputOutput(nested_cluster)))
364 return mlir::failure();
365
366 Mesh nested_mesh;
367 if (mlir::failed(ExtractMeshFromCluster(nested_cluster, &nested_mesh)))
368 return mlir::failure();
369
370 mlir::OpBuilder builder(if_op);
371 mlir::TF::IfRegionOp cloned_if;
372 CloneEmptyIfWithPredicate(if_op, nested_mesh, builder,
373 num_control_flow_send_recvs, context, &cloned_if);
374
375 // Find nested clusters in then/else branch of original `if_op` and
376 // move all inner ops inside nested cluster to `tf_cloned` in
377 // corresponding branch.
378 if (IsInsideIfThenBranch(if_op, nested_cluster)) {
379 mlir::Operation* then_branch_terminator =
380 cloned_if.then_branch().begin()->getTerminator();
381 auto& nested_cluster_operations =
382 nested_cluster.GetBody().getOperations();
383 cloned_if.then_branch().begin()->getOperations().splice(
384 then_branch_terminator->getIterator(), nested_cluster_operations,
385 nested_cluster_operations.begin(),
386 std::prev(nested_cluster_operations.end()));
387 } else {
388 mlir::Operation* else_branch_terminator =
389 cloned_if.else_branch().begin()->getTerminator();
390 auto& nested_cluster_operations =
391 nested_cluster.GetBody().getOperations();
392 cloned_if.else_branch().begin()->getOperations().splice(
393 else_branch_terminator->getIterator(), nested_cluster_operations,
394 nested_cluster_operations.begin(),
395 std::prev(nested_cluster_operations.end()));
396 }
397 nested_cluster.erase();
398 }
399 return mlir::success();
400 }
401
402 // Decomposes controlflows with nested mesh computations. When multi-mesh
403 // computation exists inside control flow operations like tf.If, then
404 // the control flow operations should be replicated to ensure correct execution
405 // semantics.
406 // For example:
407 //
408 // "tf_device.cluster"() ( {
409 // %1 = "tf.G"() : () -> (tensor<i1>)
410 // "tf.IfRegion"(%1) ({
411 // "tf_device.cluster"() ( {
412 // "tf.D"() {} : () -> ()
413 // tf_device.return
414 // }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> ()
415 //
416 // "tf.Yield"() : () -> ()
417 // }, {
418 // }) {is_stateless = false} : (tensor<i1>) -> ()
419 // tf_device.return
420 // }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> ()
421 //
422 // Above computation includes TPU device computation that exists inside
423 // tf.If op in CPU mesh. In this case, tf.If op should be replicated to TPU
424 // device computation so that `tf.D` op is executed in sync with CPU side
425 // computation. After transformation in this function, above IR is changed to:
426 //
427 // "tf_device.cluster"() ( {
428 // %1 = "tf.DTensorRecv"() : () -> tensor<i1>
429 // "tf.IfRegion"(%1) ( {
430 // "tf.D"() : () -> ()
431 // "tf.Yield"() : () -> ()
432 // }, {
433 // "tf.Yield"() : () -> ()
434 // }) {is_stateless = false} : (tensor<i1>) -> ()
435 // tf_device.return
436 // }) {_mesh = "TPU|x=1|0|0|TPU:0"} : () -> ()
437 //
438 // "tf_device.cluster"() ( {
439 // %1 = "tf.G"() : () -> tensor<i1>
440 // "tf.DTensorSend"(%1) : (tensor<i1>) -> ()
441 // "tf.IfRegion"(%1) ( {
442 // "tf.Yield"() : () -> ()
443 // }, {
444 // "tf.Yield"() : () -> ()
445 // }) {is_stateless = false} : (tensor<i1>) -> ()
446 // tf_device.return
447 // }) {_mesh = "CPU|x=1|0|0|CPU:0"} : () -> ()
448 //
449 // Note that:
450 // 1) Control flow is replicated.
451 // 2) DTensorSend/Recv ops are added to transfer predicate tensors for
452 // control flow operations
DecomposeControlflow(mlir::MLIRContext * context,int * num_control_flow_send_recvs,mlir::ModuleOp module)453 mlir::LogicalResult DecomposeControlflow(mlir::MLIRContext* context,
454 int* num_control_flow_send_recvs,
455 mlir::ModuleOp module) {
456 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
457 // Identify all clusters in topological order.
458 module.walk([&](mlir::tf_device::ClusterOp cluster) {
459 clusters.emplace_back(cluster);
460 });
461
462 for (mlir::tf_device::ClusterOp cluster : clusters) {
463 mlir::WalkResult walk_result = cluster->walk([&](mlir::Operation* op) {
464 if (auto if_op = mlir::dyn_cast<mlir::TF::IfRegionOp>(op)) {
465 if (mlir::failed(
466 DecomposeIf(if_op, context, num_control_flow_send_recvs)))
467 return mlir::WalkResult::interrupt();
468 }
469 return mlir::WalkResult::advance();
470 });
471 if (walk_result.wasInterrupted()) return mlir::failure();
472 }
473
474 return mlir::success();
475 }
476
477 // Merges multiple tf_device.clusters with same mesh specification to a single
478 // mesh cluster.
MergeClusters(mlir::ModuleOp module)479 mlir::LogicalResult MergeClusters(mlir::ModuleOp module) {
480 mlir::func::FuncOp main_func =
481 module.lookupSymbol<mlir::func::FuncOp>("main");
482
483 // Create global cluster for each mesh in entire computation.
484 auto clusters = FindAllDeviceClusters(main_func);
485 mlir::Block& func_block = *main_func.getBody().begin();
486 mlir::OpBuilder builder(&func_block.front());
487 std::map<Mesh, llvm::SmallVector<mlir::tf_device::ClusterOp, 4>> cluster_map;
488 std::vector<Mesh> meshes;
489 for (mlir::tf_device::ClusterOp cluster : clusters) {
490 Mesh mesh;
491 if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh)))
492 return mlir::failure();
493
494 if (cluster_map.find(mesh) != cluster_map.end()) {
495 cluster_map[mesh].emplace_back(cluster);
496 } else {
497 cluster_map[mesh] =
498 llvm::SmallVector<mlir::tf_device::ClusterOp, 4>{cluster};
499 meshes.push_back(std::move(mesh));
500 }
501 }
502
503 // Reevaluate if this sort is necessary after b/186804270 is closed.
504 std::sort(meshes.begin(), meshes.end(), [](const Mesh& a, const Mesh& b) {
505 if (a.device_type() != b.device_type()) {
506 return a.device_type() < b.device_type();
507 }
508 return a < b;
509 });
510 for (const Mesh& mesh : meshes) {
511 const auto& mesh_cluster_list = cluster_map[mesh];
512 llvm::SmallVector<mlir::Value, 4> merged_cluster_outputs;
513 llvm::SmallVector<mlir::Value, 4> merged_return_values;
514 llvm::SmallVector<mlir::Type, 4> merged_return_types;
515
516 for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
517 merged_cluster_outputs.insert(merged_cluster_outputs.end(),
518 cluster.results().begin(),
519 cluster.results().end());
520
521 auto return_values = cluster.GetBody().getTerminator()->getOperands();
522 merged_return_values.insert(merged_return_values.end(),
523 return_values.begin(), return_values.end());
524
525 auto return_type = cluster->getResultTypes();
526 merged_return_types.insert(merged_return_types.end(), return_type.begin(),
527 return_type.end());
528 }
529
530 // Create a single cluster op contains merged computations for `mesh`.
531 builder.setInsertionPoint(&func_block.front());
532 auto new_cluster = builder.create<mlir::tf_device::ClusterOp>(
533 module.getLoc(), merged_return_types);
534 new_cluster.body().push_back(new mlir::Block);
535 new_cluster->setAttr(kMeshAttr, builder.getStringAttr(mesh.ToString()));
536
537 // Move all ops inside clusters in cluster mesh to `new_cluster`.
538 for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
539 mlir::Block& cluster_body = cluster.GetBody();
540 for (mlir::Operation& op_to_move :
541 llvm::make_early_inc_range(cluster_body.without_terminator())) {
542 for (mlir::OpOperand& use : op_to_move.getUses()) {
543 auto return_op =
544 llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner());
545 if (!return_op) continue;
546
547 mlir::Value output = cluster.getResult(use.getOperandNumber());
548 output.replaceUsesWithIf(use.get(), [](mlir::OpOperand& operand) {
549 return operand.getOwner()
550 ->getParentOfType<mlir::tf_device::ClusterOp>() !=
551 nullptr;
552 });
553 }
554 op_to_move.moveBefore(new_cluster.getBody(),
555 new_cluster.getBody()->end());
556 }
557 }
558
559 builder.setInsertionPointToEnd(&new_cluster.GetBody());
560 builder.create<mlir::tf_device::ReturnOp>(new_cluster.getLoc(),
561 merged_return_values);
562
563 // Replace return value usages.
564 for (auto it :
565 llvm::zip(merged_cluster_outputs, new_cluster.getResults())) {
566 mlir::Value value_to_replace = std::get<0>(it);
567 mlir::Value new_result_value = std::get<1>(it);
568 value_to_replace.replaceAllUsesWith(new_result_value);
569 }
570
571 // Erase clusters in cluster_map now that all ops are moved.
572 for (mlir::tf_device::ClusterOp cluster : mesh_cluster_list) {
573 if (mlir::failed(MergeClusterMetadata(cluster, new_cluster)))
574 return mlir::failure();
575
576 cluster.erase();
577 }
578 }
579
580 return mlir::success();
581 }
582
583 // Pass that merges multiple tf_device.Cluster ops for multi-mesh computation
584 // into a single cluster. After this pass, exactly one tf_device.Cluster op
585 // exists for each device mesh.
586 struct DTensorMergeClusters
587 : public DTensorMergeClustersBase<DTensorMergeClusters> {
getDependentDialectstensorflow::dtensor::__anon216d88340111::DTensorMergeClusters588 void getDependentDialects(mlir::DialectRegistry& registry) const override {
589 registry.insert<mlir::dtensor::DTensorDialect>();
590 }
591
runOnOperationtensorflow::dtensor::__anon216d88340111::DTensorMergeClusters592 void runOnOperation() override {
593 mlir::MLIRContext& context = getContext();
594 mlir::OpBuilder op_builder(&context);
595 auto module = getOperation();
596 if (mlir::failed(InlineNestedDeviceClusters(module)))
597 return signalPassFailure();
598
599 int num_controlflow_send_recv = 0;
600 if (mlir::failed(
601 DecomposeControlflow(&context, &num_controlflow_send_recv, module)))
602 return signalPassFailure();
603
604 if (mlir::failed(MergeClusters(module))) return signalPassFailure();
605
606 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
607 module.walk([&](mlir::tf_device::ClusterOp cluster) {
608 clusters.emplace_back(cluster);
609 });
610
611 for (mlir::tf_device::ClusterOp cluster : clusters) {
612 RemoveUnusedClusterResults(cluster);
613 }
614 };
615 };
616
617 } // namespace
618
619 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorMergeClustersPass()620 CreateDTensorMergeClustersPass() {
621 return std::make_unique<DTensorMergeClusters>();
622 }
623
624 } // namespace dtensor
625 } // namespace tensorflow
626