/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorHandleCrossClusterDependencies pass. "
    "All clusters must have a specified mesh.";

constexpr char kInvalidTensorTransferErrorMsg[] =
    "CopyToMeshOp must be used to send data across meshes.";

constexpr char kInvalidLayoutMsg[] =
    "Found CopyToMesh with an invalid layout: {0}.";

// Extracts mesh from `cluster`.
mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
                                           Mesh* mesh_output) {
  auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
  if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);

  const auto& mesh_or_null = mesh_or_status.ValueOrDie();
  if (!mesh_or_null.has_value())
    return cluster.emitOpError(kMissingMeshErrorMsg);

  *mesh_output = mesh_or_null.value();
  return mlir::success();
}

// Returns the Const op if `op` is a Const op, or if `op` is a DTensorLayout
// op whose input is a Const op. Otherwise returns nullptr.
mlir::Operation* GetConstOp(mlir::Operation* op) {
  if (llvm::isa<mlir::TF::ConstOp>(op)) return op;

  if (auto layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
    mlir::Operation* input_op = layout.input().getDefiningOp();
    if (input_op && llvm::isa<mlir::TF::ConstOp>(input_op)) return input_op;
  }
  return nullptr;
}

// Creates a clone of `const_op` at the beginning of the body region of
// `cluster`, and replaces the uses of the CopyToMesh op's output within
// `cluster` with the output of the cloned op.
mlir::LogicalResult CloneOpToCluster(mlir::Operation* const_op,
                                     mlir::tf_device::ClusterOp cluster,
                                     mlir::OpOperand* operand) {
  auto copy_to_mesh =
      llvm::dyn_cast<mlir::TF::CopyToMeshOp>(operand->getOwner());
  assert(copy_to_mesh);
  const std::string layout_attr = copy_to_mesh.layout().str();
  StatusOr<Layout> layout = Layout::FromString(layout_attr);
  if (!layout.ok())
    return copy_to_mesh.emitOpError(
        llvm::formatv(kInvalidLayoutMsg, layout_attr));

  mlir::OpBuilder builder(&cluster.GetBody().front());
  mlir::Operation* cloned_op = builder.clone(*const_op);
  mlir::TensorType type =
      cloned_op->getResult(0).getType().cast<mlir::TensorType>();
  auto layout_op = builder.create<mlir::TF::DTensorLayout>(
      const_op->getLoc(), cloned_op->getResult(0),
      mlir::dtensor::LayoutAttr::get(builder.getContext(), *layout),
      mlir::TF::ShapeAttr::get(builder.getContext(), type));

  copy_to_mesh.output().replaceUsesWithIf(
      layout_op.output(), [&](mlir::OpOperand& operand) {
        return cluster.getOperation()->isProperAncestor(operand.getOwner());
      });

  if (copy_to_mesh->getUsers().empty()) copy_to_mesh.erase();

  return mlir::success();
}

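// Sets `val_output` to the value that feeds `operand`. If that value is the
// result of another tf_device.cluster, queries the terminator of that cluster
// to find the TensorFlow op result that actually produces it. For block
// arguments, `val_output` is left unset and success is returned.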
mlir::LogicalResult GetInputProducingValue(mlir::OpOperand& operand,
                                           mlir::Value* val_output) {
  auto input_value = operand.get().dyn_cast<mlir::OpResult>();
  if (!input_value) return mlir::success();

  auto input_cluster =
      llvm::dyn_cast<mlir::tf_device::ClusterOp>(input_value.getOwner());
  if (input_cluster) {
    // If the value is an output of another tf_device.cluster, query the
    // terminator of that cluster to get the mlir::Value of the TensorFlow
    // operation that produces it.
    *val_output = input_cluster.GetBody().getTerminator()->getOperand(
        input_value.getResultNumber());
  } else {
    *val_output = input_value;
  }
  return mlir::success();
}

// Copies a constant operation into a mesh cluster when the constant is used
// across multiple mesh computations. This is needed for two reasons.
// a) Cloning constants across meshes can reduce send/recvs during execution.
// b) DTensor SPMD expansion for some ops (like tf.reduce_sum) requires inputs
//    to the computation to be constants.
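//
// For example (an illustrative sketch; mesh and layout attributes are
// abbreviated), a constant consumed through CopyToMesh in another cluster:
//
//    %0 = "tf_device.cluster"() ({
//      %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//      tf_device.return %1 : tensor<i32>
//    }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
//    %2 = "tf_device.cluster"() ({
//      %3 = "tf.CopyToMesh"(%0) {layout = ...} : (tensor<i32>) -> tensor<i32>
//      ...
//    }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
//
// is rewritten so that the consuming cluster holds its own clone of the
// constant, wrapped in a DTensorLayout carrying the CopyToMesh layout:
//
//    %2 = "tf_device.cluster"() ({
//      %3 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//      %4 = "tf.DTensorLayout"(%3) {...} : (tensor<i32>) -> tensor<i32>
//      ...
//    }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)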
mlir::LogicalResult CloneConstantsAcrossMesh(
    mlir::tf_device::ClusterOp cluster) {
  auto& body_region = cluster.body();
  Mesh mesh;
  if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh)))
    return mlir::failure();

  mlir::LogicalResult result(mlir::success());
  mlir::visitUsedValuesDefinedAbove(
      body_region, body_region, [&](mlir::OpOperand* operand) {
        if (mlir::failed(result)) return;

        mlir::Value input_value;
        result = GetInputProducingValue(*operand, &input_value);
        if (mlir::failed(result) || !input_value) return;

        auto input_cluster =
            input_value.getDefiningOp()
                ->getParentOfType<mlir::tf_device::ClusterOp>();
        Mesh input_mesh;
        if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) {
          result = mlir::failure();
          return;
        }

        if (input_mesh == mesh) return;
        if (!llvm::isa<mlir::TF::CopyToMeshOp>(operand->getOwner())) {
          result =
              operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg);
          return;
        }

        mlir::Operation* const_op = GetConstOp(input_value.getDefiningOp());
        if (const_op) result = CloneOpToCluster(const_op, cluster, operand);
      });

  return result;
}

// Transforms a CopyToMesh op into a pair of DTensorSend/DTensorRecv
// operations.
mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh,
                                    mlir::MLIRContext* context,
                                    int* send_recv_counter) {
  const mlir::OpResult copied_value =
      copy_to_mesh.input().cast<mlir::OpResult>();
  const int result_index = copied_value.getResultNumber();
  auto src_cluster =
      llvm::cast<mlir::tf_device::ClusterOp>(copied_value.getDefiningOp());
  mlir::Value value_to_send =
      src_cluster.GetBody().getTerminator()->getOperand(result_index);

  // Create a DTensorSend op that sends `value_to_send` across mesh clusters.
  mlir::OpBuilder builder(value_to_send.getParentBlock()->getTerminator());

  const std::string op_key =
      llvm::formatv("communication_key_{0}_{1}", copy_to_mesh.layout(),
                    *send_recv_counter)
          .str();
  const std::string layout_attr = copy_to_mesh.layout().str();
  auto layout_or_status = Layout::FromString(layout_attr);
  if (!layout_or_status.ok())
    return copy_to_mesh.emitOpError(
        llvm::formatv(kInvalidLayoutMsg, layout_attr));

  // Create a send op that sends data from the input cluster to the target
  // cluster.
  const Layout& target_layout = layout_or_status.ValueOrDie();
  builder.create<mlir::TF::DTensorSend>(
      copy_to_mesh.getLoc(), value_to_send, builder.getStringAttr(op_key),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Create a recv op that receives the data from the send op.
  auto tensor_type = value_to_send.getType().dyn_cast<mlir::TensorType>();
  if (!tensor_type)
    return copy_to_mesh.emitOpError(
        "found CopyToMesh sending a value with unknown shape. Inputs to a "
        "CopyToMesh op must have a static shape.");

  builder.setInsertionPoint(copy_to_mesh);
  auto recv_op = builder.create<mlir::TF::DTensorRecv>(
      copy_to_mesh.getLoc(), value_to_send.getType(),
      builder.getStringAttr(op_key),
      mlir::TF::ShapeAttr::get(context, tensor_type),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Replace all usages of the `copy_to_mesh` op with the recv op's output.
  copy_to_mesh.replaceAllUsesWith(recv_op.output());

  // Remove the CopyToMesh op.
  copy_to_mesh.erase();

  *send_recv_counter += 1;

  return mlir::success();
}

// Lowers tf.CopyToMesh to a pair of DTensorSend/DTensorRecv operations.
//
// For example:
//    %0 = "tf_device.cluster"() ({
//      %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//      tf_device.return %1 : tensor<i32>
//    }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
//    %2 = "tf_device.cluster"() ({
//      %3 = "tf.CopyToMesh"(%0)
//          { layout ="mesh:TPU,x=2,y=2 layout:x,replicated" } :
//                      (tensor<i32>) -> (tensor<i32>)
//      %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32>
//      tf_device.return %4 : tensor<i32>
//    }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
//    return
// }
//
// Is transformed to:
//
//    %0 = "tf_device.cluster"() ({
//      %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//      "tf.DTensorSend"(%1) {...} : (tensor<i32>) -> ()
//      tf_device.return %1 : tensor<i32>
//    }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
//    %2 = "tf_device.cluster"() ({
//      %3 = "tf.DTensorRecv"() {...} : () -> (tensor<i32>)
//      %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32>
//      tf_device.return %4 : tensor<i32>
//    }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
//    return
// }
mlir::LogicalResult ReplaceCopyToMeshWithVirtualSendRecv(
    mlir::tf_device::ClusterOp cluster, mlir::MLIRContext* context,
    int* send_recv_counter) {
  Mesh current_mesh;
  if (mlir::failed(ExtractMeshFromCluster(cluster, &current_mesh)))
    return mlir::failure();

  mlir::Region& cluster_region = cluster.body();
  mlir::LogicalResult result = mlir::success();

  mlir::visitUsedValuesDefinedAbove(
      cluster_region, cluster_region, [&](mlir::OpOperand* operand) {
        mlir::Value input_value;
        if (mlir::failed(GetInputProducingValue(*operand, &input_value))) {
          result = mlir::failure();
          return;
        }
        if (!input_value) return;

        auto input_cluster =
            input_value.getDefiningOp()
                ->getParentOfType<mlir::tf_device::ClusterOp>();
        Mesh input_mesh;
        if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) {
          result = mlir::failure();
          return;
        }

        if (current_mesh == input_mesh) return;

        // Check that values crossing mesh boundaries go through a CopyToMesh
        // op.
        mlir::Operation* input_op = operand->getOwner();
        mlir::TF::CopyToMeshOp copy_to_mesh =
            llvm::dyn_cast<mlir::TF::CopyToMeshOp>(input_op);
        if (!copy_to_mesh) {
          result =
              operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg);
          return;
        }

        // Lower the CopyToMesh op to a pair of virtual Send/Recv ops.
        if (mlir::failed(
                LowerToSendRecv(copy_to_mesh, context, send_recv_counter))) {
          result = mlir::failure();
          return;
        }
      });
  return result;
}

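// Pass that handles dependencies that cross mesh cluster boundaries: it first
// clones constants used across meshes into the consuming clusters, then
// lowers any remaining cross-mesh tf.CopyToMesh ops to
// DTensorSend/DTensorRecv pairs.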
struct DTensorHandleCrossClusterDependencies
    : public DTensorHandleCrossClusterDependenciesBase<
          DTensorHandleCrossClusterDependencies> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::ModuleOp module = getOperation();
    llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
    module.walk([&](mlir::tf_device::ClusterOp cluster) {
      clusters.emplace_back(cluster);
    });

    int send_recv_counter = 0;
    for (auto cluster : clusters) {
      if (mlir::failed(CloneConstantsAcrossMesh(cluster)))
        return signalPassFailure();

      if (mlir::failed(ReplaceCopyToMeshWithVirtualSendRecv(
              cluster, &context, &send_recv_counter)))
        return signalPassFailure();
    }

    // Once CopyToMesh has been lowered to DTensorSend/Recv operations,
    // tf_device.cluster ops may have dangling/unused result values. Remove
    // all such return values.
    for (auto cluster : clusters) RemoveUnusedClusterResults(cluster);
  }
};

}  // namespace

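// A minimal usage sketch (hypothetical; the surrounding pipeline setup is not
// part of this file): the factory below would typically be added to a
// module-level pass pipeline, e.g.
//   mlir::PassManager pm(&context);
//   pm.addPass(CreateDTensorHandleCrossClusterDependencies());
//   (void)pm.run(module);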
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorHandleCrossClusterDependencies() {
  return std::make_unique<DTensorHandleCrossClusterDependencies>();
}

}  // namespace dtensor
}  // namespace tensorflow