// tensorflow/dtensor/mlir/lower_send_recv.cc
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Error emitted when a tf_device.Cluster reaching this pass carries no mesh
// attribute; every cluster must have one by this point in the pipeline.
constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorMergeCluster pass. "
    "All clusters must have specified mesh.";

48 // Extracts mesh from `cluster`.
ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,Mesh * mesh_output)49 mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
50                                            Mesh* mesh_output) {
51   auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
52   if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);
53 
54   const absl::optional<Mesh>& mesh_or_null = *mesh_or_status;
55   if (!mesh_or_null.has_value())
56     return cluster.emitOpError(kMissingMeshErrorMsg);
57 
58   *mesh_output = mesh_or_null.value();
59   return mlir::success();
60 }
61 
62 // Find all DTesorSend/Recv ops and lower into TF/XLA Send/Recv operations with
63 // execution kernels.
LowerDTensorSendRecvsOps(mlir::ModuleOp module)64 mlir::LogicalResult LowerDTensorSendRecvsOps(mlir::ModuleOp module) {
65   mlir::LogicalResult result = mlir::success();
66   module.walk([&](mlir::TF::DTensorSend send_op) {
67     if (mlir::failed(result)) return;
68 
69     auto recv_op = GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorSend>(
70         module, send_op);
71     if (!recv_op.ok()) {
72       result = send_op.emitOpError(recv_op.status().error_message());
73       return;
74     }
75     auto dtensor_recv = llvm::dyn_cast<mlir::TF::DTensorRecv>(*recv_op);
76     if (!dtensor_recv) {
77       result = send_op.emitOpError(
78           "Cannot find a matching DTensorRecv op for this DTensorSend op");
79       return;
80     }
81     const Mesh recv_mesh = dtensor_recv.layout().mesh();
82 
83     Mesh send_mesh;
84     if (mlir::failed(ExtractMeshFromCluster(
85             send_op->getParentOfType<mlir::tf_device::ClusterOp>(),
86             &send_mesh))) {
87       result = mlir::failure();
88       return;
89     }
90 
91     if (!send_mesh.is_tpu_mesh() && !recv_mesh.is_tpu_mesh()) {
92       result = send_op->emitOpError(
93           "Multi-mesh tensor transfer between non-xla devices are not yet "
94           "supported.");
95       return;
96     }
97 
98     const Layout recv_layout =
99         Layout::ReplicatedOnMesh(recv_mesh, ValueRank(dtensor_recv.output()));
100     const Layout send_input_layout =
101         Layout::ReplicatedOnMesh(send_mesh, ValueRank(send_op.input()));
102 
103     StatusOr<mlir::Operation*> lowered_recv =
104         LowerDTensorRecvToXlaOp(dtensor_recv);
105     if (!lowered_recv.ok()) {
106       result = dtensor_recv->emitOpError(lowered_recv.status().error_message());
107       return;
108     }
109     dtensor_recv->replaceAllUsesWith(*lowered_recv);
110     dtensor_recv.erase();
111 
112     auto lowered_send_or =
113         LowerDTensorSendToXlaOp(send_input_layout, send_op.input(), send_op,
114                                 /*from_spmd_expander=*/false);
115     if (!lowered_send_or.ok()) {
116       result = send_op->emitOpError(lowered_send_or.status().error_message());
117       return;
118     }
119   });
120   return result;
121 }
122 
123 // Adds Identity Op that uses device_id argument as inputs for clusters that
124 // does not have device id usages. When send/recv operations exists in
125 // tf_device.Clusters to transfer data across mesh clusters, device_id argument
126 // is required. However, mlir::func::FuncOp's created by transforming
127 // tf_device.Cluster to tf_device.ClusterFunc during ClusterOutlining pass will
128 // **not** include device_id as input argument if there are no usages within the
129 // cluster op body. As so, add Identity op that uses device_id argument from
130 // main function in all tf_device.Clusters so that device_id argument can be
131 // retained when converting tf_device.Cluster to functions.
PropagateDeviceIdToClusters(mlir::ModuleOp module)132 void PropagateDeviceIdToClusters(mlir::ModuleOp module) {
133   mlir::WalkResult result = module.walk([&](mlir::Operation* op) {
134     if (llvm::isa<mlir::TF::_XlaSendFromHostOp, mlir::TF::_XlaRecvAtHostV2Op,
135                   mlir::TF::XlaSendToHostOp, mlir::TF::XlaRecvFromHostOp,
136                   mlir::TF::_HostSendOp, mlir::TF::_HostRecvOp,
137                   mlir::TF::SendOp, mlir::TF::RecvOp>(op))
138       return mlir::WalkResult::interrupt();
139     return mlir::WalkResult::advance();
140   });
141 
142   const bool has_cross_mesh_send_recv = result.wasInterrupted();
143   if (!has_cross_mesh_send_recv) return;
144 
145   mlir::func::FuncOp main_func =
146       module.lookupSymbol<mlir::func::FuncOp>("main");
147   auto device_id = DeviceId(main_func);
148 
149   module.walk([&](mlir::tf_device::ClusterOp op) {
150     mlir::OpBuilder builder(&op.GetBody().front());
151     builder.create<mlir::TF::IdentityOp>(main_func.getLoc(),
152                                          device_id->getType(), *device_id);
153   });
154 }
155 
156 // Pass that merges multiple tf_device.Cluster ops for multi-mesh computation
157 // into a single cluster. After this pass, exactly one tf_device.Cluster op
158 // exists for each device mesh.
159 struct DTensorLowerSendRecv
160     : public DTensorLowerSendRecvBase<DTensorLowerSendRecv> {
runOnOperationtensorflow::dtensor::__anonf861fb510111::DTensorLowerSendRecv161   void runOnOperation() override {
162     mlir::MLIRContext& context = getContext();
163     mlir::OpBuilder op_builder(&context);
164     auto module = getOperation();
165 
166     // Merging clusters and decomposing control flow may have created new
167     // DTensorSend/DTensorRecv ops. Lower DTensorSend/DTensorRecv ops added by
168     // above transformations.
169     if (mlir::failed(LowerDTensorSendRecvsOps(module)))
170       return signalPassFailure();
171 
172     // Ensure that all mesh clusters have at least one usages of device_id
173     // argument from main function to guarantee that device_id argument is
174     // retained after ClusterOutlinging.
175     PropagateDeviceIdToClusters(module);
176   };
177 };

}  // namespace

181 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorLowerSendRecv()182 CreateDTensorLowerSendRecv() {
183   return std::make_unique<DTensorLowerSendRecv>();
184 }

}  // namespace dtensor
}  // namespace tensorflow