1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "llvm/Support/Casting.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/IR/Visitors.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassManager.h" // from @llvm-project
28 #include "mlir/Support/LogicalResult.h" // from @llvm-project
29 #include "mlir/Transforms/Passes.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
34 #include "tensorflow/dtensor/cc/constants.h"
35 #include "tensorflow/dtensor/cc/tensor_layout.h"
36 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
37 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
38 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
39 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
40 #include "tensorflow/dtensor/mlir/layout_parsing.h"
41
42 namespace tensorflow {
43 namespace dtensor {
44
45 namespace {
46
47 // Extracts mesh config from the Op.
48 // We currently hard extract mesh information from all the args and assume they
49 // are the same. This should not be the case when we have multiple functions.
WrapDeviceCluster(mlir::OpBuilder * builder,mlir::Operation * op)50 mlir::LogicalResult WrapDeviceCluster(mlir::OpBuilder *builder,
51 mlir::Operation *op) {
52 // Create new tf_device.cluster op wrapping a single operation.
53 builder->setInsertionPoint(op);
54 auto cluster = builder->create<mlir::tf_device::ClusterOp>(
55 op->getLoc(), op->getResultTypes());
56 if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
57 cluster->setAttr(kMeshAttr, builder->getStringAttr(
58 layout_op.layout().mesh().ToString()));
59 } else if (auto copy_to_mesh = llvm::dyn_cast<mlir::TF::CopyToMeshOp>(op)) {
60 const std::string layout_string = copy_to_mesh.layout().str();
61 auto layout_or = Layout::FromString(layout_string);
62 if (!layout_or.ok())
63 return op->emitOpError(
64 llvm::formatv("Found tf.CopyToMesh Op with unparsable layout : {0}",
65 layout_string));
66
67 cluster->setAttr(kMeshAttr,
68 builder->getStringAttr(layout_or->mesh().ToString()));
69 } else {
70 // If mesh configuration can be inferred from the op directly, use the mesh
71 // information from op attribute directly. If op is not annotated with mesh
72 // information, then mesh will be inferred in following
73 // DTensorMeshPropagation pass and will be inferred from consumers or
74 // operands.
75 auto status_or_mesh = ExtractDeviceMeshFromOp(op);
76
77 if (!status_or_mesh.ok())
78 return op->emitOpError(
79 llvm::formatv("failed to wrap to device cluster. {0}",
80 status_or_mesh.status().error_message()));
81
82 const auto mesh_config = status_or_mesh.ValueOrDie();
83 if (mesh_config)
84 cluster->setAttr(kMeshAttr,
85 builder->getStringAttr(mesh_config->ToString()));
86 }
87
88 op->replaceAllUsesWith(cluster);
89
90 cluster.body().push_back(new mlir::Block);
91
92 builder->setInsertionPointToEnd(&cluster.GetBody());
93 builder->create<mlir::tf_device::ReturnOp>(op->getLoc(), op->getResults());
94
95 // Move `op` inside newly created `ClusterOp`.
96 op->moveBefore(cluster.GetBody().getTerminator());
97
98 return mlir::success();
99 }
100
101 // MLIR pass that wraps tf_device.cluster op to every TF op.
102 struct DTensorOpToDeviceClusterPass
103 : public DTensorOpToDeviceClusterBase<DTensorOpToDeviceClusterPass> {
getDependentDialectstensorflow::dtensor::__anonef5b7c610111::DTensorOpToDeviceClusterPass104 void getDependentDialects(mlir::DialectRegistry ®istry) const override {
105 registry.insert<mlir::dtensor::DTensorDialect>();
106 registry.insert<mlir::tf_device::TensorFlowDeviceDialect>();
107 }
108
runOnOperationtensorflow::dtensor::__anonef5b7c610111::DTensorOpToDeviceClusterPass109 void runOnOperation() override {
110 mlir::MLIRContext &context = getContext();
111 mlir::OpBuilder op_builder(&context);
112 mlir::Dialect *tf =
113 getContext().getLoadedDialect<mlir::TF::TensorFlowDialect>();
114
115 auto walk_result = getOperation().walk([&](mlir::Operation *operation) {
116 const auto op_dialect = operation->getDialect();
117 // Only TF dialects are supported for layout propagation.
118 if (op_dialect != tf) return mlir::WalkResult::advance();
119
120 // For control flow operations, tf.yield ops exists and should not be
121 // wrapped to tf_device.cluster as the op does not need to be transformed
122 // in SPMD expansion and tf.If/tf.While op require all ops to terminate
123 // with tf.Yield op. Wrapping yield op in tf_device.cluster invalidates
124 // this invariant.
125 if (llvm::isa<mlir::TF::YieldOp>(operation))
126 return mlir::WalkResult::advance();
127
128 if (mlir::failed(WrapDeviceCluster(&op_builder, operation)))
129 return mlir::WalkResult::interrupt();
130 return mlir::WalkResult::advance();
131 });
132
133 if (walk_result.wasInterrupted()) signalPassFailure();
134 }
135 };
136
137 } // namespace
138
// Returns a function pass that wraps each TensorFlow dialect op in its own
// tf_device.cluster op, annotated with a mesh attribute when one can be
// determined.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorOpToDeviceClusterPass() {
  return std::make_unique<DTensorOpToDeviceClusterPass>();
}
143
144 } // namespace dtensor
145 } // namespace tensorflow
146