1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "mlir/IR/Block.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/UseDefLists.h" // from @llvm-project
27 #include "mlir/IR/Value.h" // from @llvm-project
28 #include "mlir/Support/LogicalResult.h" // from @llvm-project
29 #include "mlir/Transforms/Passes.h" // from @llvm-project
30 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
33 #include "tensorflow/dtensor/cc/constants.h"
34 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
35 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
36 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
37 #include "tensorflow/dtensor/mlir/layout_parsing.h"
38 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
39
40 namespace tensorflow {
41 namespace dtensor {
42 namespace {
43
// Error emitted when a tf_device.cluster lacks a mesh attribute (see
// ExtractMeshFromCluster below).
constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorHandleCrossClusterDependencies pass. "
    "All clusters must have specified mesh.";

// Error emitted when a value crosses a mesh boundary without going through a
// tf.CopyToMesh op.
constexpr char kInvalidTensorTransferErrorMsg[] =
    "CopyToMeshOp must be used to send data across mesh.";

// Error emitted when a CopyToMesh op carries a layout string that
// Layout::FromString cannot parse. {0} is the offending layout string
// (llvm::formatv placeholder).
constexpr char kInvalidLayoutMsg[] =
    "found CopyToMesh with invalid layout. Found layout {0}.";
53
54 // Extracts mesh from `cluster`.
ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,Mesh * mesh_output)55 mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
56 Mesh* mesh_output) {
57 auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
58 if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);
59
60 const auto& mesh_or_null = mesh_or_status.ValueOrDie();
61 if (!mesh_or_null.has_value())
62 return cluster.emitOpError(kMissingMeshErrorMsg);
63
64 *mesh_output = mesh_or_null.value();
65 return mlir::success();
66 }
67
68 // Returns const op if `op` is a const op or DTensorLayoutOp with Const op as
69 // input.
GetConstOp(mlir::Operation * op)70 mlir::Operation* GetConstOp(mlir::Operation* op) {
71 if (llvm::isa<mlir::TF::ConstOp>(op)) return op;
72
73 if (auto layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
74 mlir::Operation* input_op = layout.input().getDefiningOp();
75 if (input_op && llvm::isa<mlir::TF::ConstOp>(input_op)) return input_op;
76 }
77 return nullptr;
78 }
79
80 // Creates a clone of `const_op` at the beginning of `cluster` body region and
81 // set the output value of cloned op replace output of CopyToMesh op within
82 // `cluster`.
CloneOpToCluster(mlir::Operation * const_op,mlir::tf_device::ClusterOp cluster,mlir::OpOperand * operand)83 mlir::LogicalResult CloneOpToCluster(mlir::Operation* const_op,
84 mlir::tf_device::ClusterOp cluster,
85 mlir::OpOperand* operand) {
86 auto copy_to_mesh =
87 llvm::dyn_cast<mlir::TF::CopyToMeshOp>(operand->getOwner());
88 assert(copy_to_mesh);
89 const std::string layout_attr = copy_to_mesh.layout().str();
90 StatusOr<Layout> layout = Layout::FromString(layout_attr);
91 if (!layout.ok())
92 return copy_to_mesh.emitOpError(
93 llvm::formatv(kInvalidLayoutMsg, layout_attr));
94
95 mlir::OpBuilder builder(&cluster.GetBody().front());
96 mlir::Operation* cloned_op = builder.clone(*const_op);
97 mlir::TensorType type =
98 cloned_op->getResult(0).getType().cast<mlir::TensorType>();
99 auto layout_op = builder.create<mlir::TF::DTensorLayout>(
100 const_op->getLoc(), cloned_op->getResult(0),
101 mlir::dtensor::LayoutAttr::get(builder.getContext(), *layout),
102 mlir::TF::ShapeAttr::get(builder.getContext(), type));
103
104 copy_to_mesh.output().replaceUsesWithIf(
105 layout_op.output(), [&](mlir::OpOperand& operand) {
106 return cluster.getOperation()->isProperAncestor(operand.getOwner());
107 });
108
109 if (copy_to_mesh->getUsers().empty()) copy_to_mesh.erase();
110
111 return mlir::success();
112 }
113
GetInputProducingValue(mlir::OpOperand & operand,mlir::Value * val_output)114 mlir::LogicalResult GetInputProducingValue(mlir::OpOperand& operand,
115 mlir::Value* val_output) {
116 auto input_value = operand.get().dyn_cast<mlir::OpResult>();
117 if (!input_value) return mlir::success();
118
119 auto input_cluster =
120 llvm::dyn_cast<mlir::tf_device::ClusterOp>(input_value.getOwner());
121 if (input_cluster) {
122 // If value is from another tf_device.cluster output, then query into
123 // the terminator of the input cluster to get mlir::Value from Tensorflow
124 // operation that is producing the value.
125 *val_output = input_cluster.GetBody().getTerminator()->getOperand(
126 input_value.getResultNumber());
127 } else {
128 *val_output = input_value;
129 }
130 return mlir::success();
131 }
132
133 // Copies constant operation to mesh clusters if there are multiple usages of
134 // constants across multiple mesh computations. This is needed for 2 reasons.
135 // a) Cloning constants across mesh can reduce send/recvs during execution.
136 // b) DTensor SPMD Expansion for some ops (like tf.reduce_sum) requires inputs
137 // to computation to be constants.
CloneConstantsAcrossMesh(mlir::tf_device::ClusterOp cluster)138 mlir::LogicalResult CloneConstantsAcrossMesh(
139 mlir::tf_device::ClusterOp cluster) {
140 auto& body_region = cluster.body();
141 Mesh mesh;
142 if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh)))
143 return mlir::failure();
144
145 mlir::LogicalResult result(mlir::success());
146 mlir::visitUsedValuesDefinedAbove(
147 body_region, body_region, [&](mlir::OpOperand* operand) {
148 if (mlir::failed(result)) return;
149
150 mlir::Value input_value;
151 result = GetInputProducingValue(*operand, &input_value);
152 if (mlir::failed(result) || !input_value) return;
153
154 auto input_cluster =
155 input_value.getDefiningOp()
156 ->getParentOfType<mlir::tf_device::ClusterOp>();
157 Mesh input_mesh;
158 if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) {
159 result = mlir::failure();
160 return;
161 }
162
163 if (input_mesh == mesh) return;
164 if (!llvm::isa<mlir::TF::CopyToMeshOp>(operand->getOwner())) {
165 result =
166 operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg);
167 return;
168 }
169
170 mlir::Operation* const_op = GetConstOp(input_value.getDefiningOp());
171 if (const_op) result = CloneOpToCluster(const_op, cluster, operand);
172 });
173
174 return result;
175 }
176
177 // Transforms CopyToMesh op to a pair of DTensorSend/DTensorRecv operations.
LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh,mlir::MLIRContext * context,int * send_recv_counter)178 mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh,
179 mlir::MLIRContext* context,
180 int* send_recv_counter) {
181 const mlir::OpResult copied_value =
182 copy_to_mesh.input().cast<mlir::OpResult>();
183 const int result_index = copied_value.getResultNumber();
184 auto src_cluster =
185 llvm::cast<mlir::tf_device::ClusterOp>(copied_value.getDefiningOp());
186 mlir::Value value_to_send =
187 src_cluster.GetBody().getTerminator()->getOperand(result_index);
188
189 // Create DTensorSend op that sends `value_to_send` across mesh cluster.
190 mlir::OpBuilder builder(value_to_send.getParentBlock()->getTerminator());
191
192 const std::string op_key =
193 llvm::formatv("communication_key_{0}_{1}", copy_to_mesh.layout(),
194 *send_recv_counter)
195 .str();
196 const std::string layout_attr = copy_to_mesh.layout().str();
197 auto layout_or_status = Layout::FromString(layout_attr);
198 if (!layout_or_status.ok())
199 return copy_to_mesh.emitOpError(
200 llvm::formatv(kInvalidLayoutMsg, layout_attr));
201
202 // Create send op that sends data from input cluster to target cluster.
203 const Layout& target_layout = layout_or_status.ValueOrDie();
204 builder.create<mlir::TF::DTensorSend>(
205 copy_to_mesh.getLoc(), value_to_send, builder.getStringAttr(op_key),
206 mlir::dtensor::LayoutAttr::get(context, target_layout));
207
208 // Create recv op that recvs data from send op.
209 auto tensor_type = value_to_send.getType().dyn_cast<mlir::TensorType>();
210 if (!tensor_type)
211 return copy_to_mesh.emitOpError(
212 "found CopyToMesh sending value with unknown shape. Inputs to "
213 "CopyToMesh op must have static shape.");
214
215 builder.setInsertionPoint(copy_to_mesh);
216 auto recv_op = builder.create<mlir::TF::DTensorRecv>(
217 copy_to_mesh.getLoc(), value_to_send.getType(),
218 builder.getStringAttr(op_key),
219 mlir::TF::ShapeAttr::get(context, tensor_type),
220 mlir::dtensor::LayoutAttr::get(context, target_layout));
221
222 // Replace value for recv ops for all usages of `copy_to_mesh` op.
223 copy_to_mesh.replaceAllUsesWith(recv_op.output());
224
225 // Remove copy to mesh op.
226 copy_to_mesh.erase();
227
228 *send_recv_counter += 1;
229
230 return mlir::success();
231 }
232
233 // Lowers tf.CopyToMesh to a pair of DTensorSend/DTensorRecv operations.
234 //
235 // For example:
236 // %0 = "tf_device.cluster"() ({
237 // %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
238 // tf_device.return %1 : tensor<i32>
239 // }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
240 //
241 // %2 = "tf_device.cluster"() ({
242 // %3 = "tf.CopyToMesh"(%0)
243 // { layout ="mesh:TPU,x=2,y=2 layout:x,replicated" } :
244 // (tensor<i32>) -> (tensor<i32>)
245 // %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32>
246 // tf_device.return %4 : tensor<i32>
247 // }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
248 // return
249 // }
250 //
251 // Is transformed to:
252 //
253 // %0 = "tf_device.cluster"() ({
254 // %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
255 // "tf.DTensorSend"(%1) {...} : (tensor<i32>) -> ()
256 // tf_device.return %1 : tensor<i32>
257 // }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
258 //
259 // %2 = "tf_device.cluster"() ({
260 // %3 = "tf.DTensorRecv"() {...} : () -> (tensor<i32>)
261 // %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32>
262 // tf_device.return %4 : tensor<i32>
263 // }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
264 // return
265 // }
ReplaceCopyToMeshWithVirtualSendRecv(mlir::tf_device::ClusterOp cluster,mlir::MLIRContext * context,int * send_recv_counter)266 mlir::LogicalResult ReplaceCopyToMeshWithVirtualSendRecv(
267 mlir::tf_device::ClusterOp cluster, mlir::MLIRContext* context,
268 int* send_recv_counter) {
269 Mesh current_mesh;
270 if (mlir::failed(ExtractMeshFromCluster(cluster, ¤t_mesh)))
271 return mlir::failure();
272
273 mlir::Region& cluster_region = cluster.body();
274 mlir::LogicalResult result = mlir::success();
275
276 mlir::visitUsedValuesDefinedAbove(
277 cluster_region, cluster_region, [&](mlir::OpOperand* operand) {
278 mlir::Value input_value;
279 if (mlir::failed(GetInputProducingValue(*operand, &input_value))) {
280 result = mlir::failure();
281 return;
282 }
283 if (!input_value) return;
284
285 auto input_cluster =
286 input_value.getDefiningOp()
287 ->getParentOfType<mlir::tf_device::ClusterOp>();
288 Mesh input_mesh;
289 if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) {
290 result = mlir::failure();
291 return;
292 }
293
294 if (current_mesh == input_mesh) return;
295
296 // Check that values that cross mesh boundaries go through CopyToMesh
297 // op.
298 mlir::Operation* input_op = operand->getOwner();
299 mlir::TF::CopyToMeshOp copy_to_mesh =
300 llvm::dyn_cast<mlir::TF::CopyToMeshOp>(input_op);
301 if (!copy_to_mesh) {
302 result =
303 operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg);
304 return;
305 }
306
307 // Lower CopyToMesh op to a pair of virtual Send/Recv op.
308 if (mlir::failed(
309 LowerToSendRecv(copy_to_mesh, context, send_recv_counter))) {
310 result = mlir::failure();
311 return;
312 }
313 });
314 return result;
315 }
316
317 struct DTensorHandleCrossClusterDependencies
318 : public DTensorHandleCrossClusterDependenciesBase<
319 DTensorHandleCrossClusterDependencies> {
getDependentDialectstensorflow::dtensor::__anon136dc6ec0111::DTensorHandleCrossClusterDependencies320 void getDependentDialects(mlir::DialectRegistry& registry) const override {
321 registry.insert<mlir::dtensor::DTensorDialect>();
322 }
323
runOnOperationtensorflow::dtensor::__anon136dc6ec0111::DTensorHandleCrossClusterDependencies324 void runOnOperation() override {
325 mlir::MLIRContext& context = getContext();
326 mlir::ModuleOp module = getOperation();
327 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
328 module.walk([&](mlir::tf_device::ClusterOp cluster) {
329 clusters.emplace_back(cluster);
330 });
331
332 int send_recv_counter = 0;
333 for (auto cluster : clusters) {
334 if (mlir::failed(CloneConstantsAcrossMesh(cluster)))
335 return signalPassFailure();
336
337 if (mlir::failed(ReplaceCopyToMeshWithVirtualSendRecv(
338 cluster, &context, &send_recv_counter)))
339 return signalPassFailure();
340 }
341
342 // Once CopyToMesh has been lowered to DTensorSend/Recv operations,
343 // tf_device.Cluster may now have dangling/unused result values. Remove all
344 // such return values.
345 for (auto cluster : clusters) RemoveUnusedClusterResults(cluster);
346 }
347 };
348
349 } // namespace
350
// Factory for the DTensorHandleCrossClusterDependencies module pass; used by
// the DTensor pass-pipeline registration.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorHandleCrossClusterDependencies() {
  return std::make_unique<DTensorHandleCrossClusterDependencies>();
}
355
356 } // namespace dtensor
357 } // namespace tensorflow
358