1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
17 #include "mlir/IR/Builders.h" // from @llvm-project
18 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
19 #include "mlir/IR/Operation.h" // from @llvm-project
20 #include "mlir/IR/Types.h" // from @llvm-project
21 #include "mlir/IR/Value.h" // from @llvm-project
22 #include "mlir/IR/Visitors.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "mlir/Pass/PassManager.h" // from @llvm-project
25 #include "mlir/Support/LogicalResult.h" // from @llvm-project
26 #include "mlir/Transforms/Passes.h" // from @llvm-project
27 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/dtensor/cc/constants.h"
31 #include "tensorflow/dtensor/mlir/device_utils.h"
32 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
34 #include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
35 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
36 #include "tensorflow/dtensor/mlir/layout_parsing.h"
37 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
38 #include "tensorflow/dtensor/mlir/value_utils.h"
39
40 namespace tensorflow {
41 namespace dtensor {
42 namespace {
43
// Error emitted when a tf_device.ClusterOp has no mesh attribute attached.
// NOTE: previously named the unrelated "DTensorMergeCluster" pass; corrected
// to reference the pass implemented in this file.
constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorLowerSendRecv pass. "
    "All clusters must have specified mesh.";
47
48 // Extracts mesh from `cluster`.
ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,Mesh * mesh_output)49 mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
50 Mesh* mesh_output) {
51 auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
52 if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);
53
54 const absl::optional<Mesh>& mesh_or_null = *mesh_or_status;
55 if (!mesh_or_null.has_value())
56 return cluster.emitOpError(kMissingMeshErrorMsg);
57
58 *mesh_output = mesh_or_null.value();
59 return mlir::success();
60 }
61
// Finds all DTensorSend/DTensorRecv ops and lowers them into TF/XLA Send/Recv
// operations with execution kernels.
mlir::LogicalResult LowerDTensorSendRecvsOps(mlir::ModuleOp module) {
  mlir::LogicalResult result = mlir::success();
  module.walk([&](mlir::TF::DTensorSend send_op) {
    // `result` latches the first failure; later callbacks become no-ops.
    if (mlir::failed(result)) return;

    // Find the receive op paired with this DTensorSend anywhere in the module.
    auto recv_op = GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorSend>(
        module, send_op);
    if (!recv_op.ok()) {
      result = send_op.emitOpError(recv_op.status().error_message());
      return;
    }
    auto dtensor_recv = llvm::dyn_cast<mlir::TF::DTensorRecv>(*recv_op);
    if (!dtensor_recv) {
      result = send_op.emitOpError(
          "Cannot find a matching DTensorRecv op for this DTensorSend op");
      return;
    }
    // Receiving mesh comes from the recv op's layout attribute; sending mesh
    // comes from the tf_device.Cluster enclosing the send op.
    const Mesh recv_mesh = dtensor_recv.layout().mesh();

    Mesh send_mesh;
    if (mlir::failed(ExtractMeshFromCluster(
            send_op->getParentOfType<mlir::tf_device::ClusterOp>(),
            &send_mesh))) {
      result = mlir::failure();
      return;
    }

    // At least one endpoint must be a TPU (XLA) mesh; pure non-XLA transfers
    // are rejected here.
    if (!send_mesh.is_tpu_mesh() && !recv_mesh.is_tpu_mesh()) {
      result = send_op->emitOpError(
          "Multi-mesh tensor transfer between non-xla devices are not yet "
          "supported.");
      return;
    }

    // Cross-mesh transfers use fully replicated layouts on both endpoints.
    // NOTE(review): `recv_layout` is computed but not referenced below —
    // confirm whether it can be removed or was meant to be passed to
    // LowerDTensorRecvToXlaOp.
    const Layout recv_layout =
        Layout::ReplicatedOnMesh(recv_mesh, ValueRank(dtensor_recv.output()));
    const Layout send_input_layout =
        Layout::ReplicatedOnMesh(send_mesh, ValueRank(send_op.input()));

    // Lower the recv first so its replacement value exists before the
    // DTensorRecv op is erased from the module.
    StatusOr<mlir::Operation*> lowered_recv =
        LowerDTensorRecvToXlaOp(dtensor_recv);
    if (!lowered_recv.ok()) {
      result = dtensor_recv->emitOpError(lowered_recv.status().error_message());
      return;
    }
    dtensor_recv->replaceAllUsesWith(*lowered_recv);
    dtensor_recv.erase();

    auto lowered_send_or =
        LowerDTensorSendToXlaOp(send_input_layout, send_op.input(), send_op,
                                /*from_spmd_expander=*/false);
    if (!lowered_send_or.ok()) {
      result = send_op->emitOpError(lowered_send_or.status().error_message());
      return;
    }
  });
  return result;
}
122
123 // Adds Identity Op that uses device_id argument as inputs for clusters that
124 // does not have device id usages. When send/recv operations exists in
125 // tf_device.Clusters to transfer data across mesh clusters, device_id argument
126 // is required. However, mlir::func::FuncOp's created by transforming
127 // tf_device.Cluster to tf_device.ClusterFunc during ClusterOutlining pass will
128 // **not** include device_id as input argument if there are no usages within the
129 // cluster op body. As so, add Identity op that uses device_id argument from
130 // main function in all tf_device.Clusters so that device_id argument can be
131 // retained when converting tf_device.Cluster to functions.
PropagateDeviceIdToClusters(mlir::ModuleOp module)132 void PropagateDeviceIdToClusters(mlir::ModuleOp module) {
133 mlir::WalkResult result = module.walk([&](mlir::Operation* op) {
134 if (llvm::isa<mlir::TF::_XlaSendFromHostOp, mlir::TF::_XlaRecvAtHostV2Op,
135 mlir::TF::XlaSendToHostOp, mlir::TF::XlaRecvFromHostOp,
136 mlir::TF::_HostSendOp, mlir::TF::_HostRecvOp,
137 mlir::TF::SendOp, mlir::TF::RecvOp>(op))
138 return mlir::WalkResult::interrupt();
139 return mlir::WalkResult::advance();
140 });
141
142 const bool has_cross_mesh_send_recv = result.wasInterrupted();
143 if (!has_cross_mesh_send_recv) return;
144
145 mlir::func::FuncOp main_func =
146 module.lookupSymbol<mlir::func::FuncOp>("main");
147 auto device_id = DeviceId(main_func);
148
149 module.walk([&](mlir::tf_device::ClusterOp op) {
150 mlir::OpBuilder builder(&op.GetBody().front());
151 builder.create<mlir::TF::IdentityOp>(main_func.getLoc(),
152 device_id->getType(), *device_id);
153 });
154 }
155
156 // Pass that merges multiple tf_device.Cluster ops for multi-mesh computation
157 // into a single cluster. After this pass, exactly one tf_device.Cluster op
158 // exists for each device mesh.
159 struct DTensorLowerSendRecv
160 : public DTensorLowerSendRecvBase<DTensorLowerSendRecv> {
runOnOperationtensorflow::dtensor::__anonf861fb510111::DTensorLowerSendRecv161 void runOnOperation() override {
162 mlir::MLIRContext& context = getContext();
163 mlir::OpBuilder op_builder(&context);
164 auto module = getOperation();
165
166 // Merging clusters and decomposing control flow may have created new
167 // DTensorSend/DTensorRecv ops. Lower DTensorSend/DTensorRecv ops added by
168 // above transformations.
169 if (mlir::failed(LowerDTensorSendRecvsOps(module)))
170 return signalPassFailure();
171
172 // Ensure that all mesh clusters have at least one usages of device_id
173 // argument from main function to guarantee that device_id argument is
174 // retained after ClusterOutlinging.
175 PropagateDeviceIdToClusters(module);
176 };
177 };
178
179 } // namespace
180
// Factory used by pass registration to create a DTensorLowerSendRecv pass
// instance operating on a ModuleOp.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorLowerSendRecv() {
  return std::make_unique<DTensorLowerSendRecv>();
}
185
186 } // namespace dtensor
187 } // namespace tensorflow
188