1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17 #include <utility>
18
19 #include "absl/types/optional.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/Builders.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/Diagnostics.h" // from @llvm-project
28 #include "mlir/IR/Operation.h" // from @llvm-project
29 #include "mlir/IR/Value.h" // from @llvm-project
30 #include "mlir/IR/Visitors.h" // from @llvm-project
31 #include "mlir/Pass/Pass.h" // from @llvm-project
32 #include "mlir/Pass/PassManager.h" // from @llvm-project
33 #include "mlir/Support/LogicalResult.h" // from @llvm-project
34 #include "mlir/Transforms/Passes.h" // from @llvm-project
35 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
38 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
39 #include "tensorflow/dtensor/cc/constants.h"
40 #include "tensorflow/dtensor/cc/tensor_layout.h"
41 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
42 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
43 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
44 #include "tensorflow/dtensor/mlir/layout_parsing.h"
45 #include "tensorflow/dtensor/mlir/op_utils.h"
46 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
47
48 namespace tensorflow {
49 namespace dtensor {
50 namespace {
51
52 // Extracts mesh of `block_arg` by parsing function argument attributes of it's
53 // enclosing function. Mesh is inferred either using `tf._layout` or `tf._mesh`
54 // attributes.
ExtractMeshFromBlockArgument(mlir::BlockArgument block_arg,absl::optional<Mesh> * out)55 mlir::LogicalResult ExtractMeshFromBlockArgument(mlir::BlockArgument block_arg,
56 absl::optional<Mesh>* out) {
57 auto func_op = mlir::dyn_cast_or_null<mlir::func::FuncOp>(
58 block_arg.getOwner()->getParentOp());
59 if (!func_op) {
60 return block_arg.getOwner()->getParentOp()->emitOpError(
61 "must be enclosed by a function");
62 }
63 auto layout_or_status = ExtractLayoutFromOperand(block_arg);
64 if (!layout_or_status.ok())
65 return func_op.emitOpError(layout_or_status.status().error_message());
66
67 if (layout_or_status->has_value()) {
68 out->emplace(layout_or_status->value().mesh());
69 return mlir::success();
70 }
71
72 auto mesh_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
73 block_arg.getArgNumber(), kCustomDeviceMeshAttr);
74 if (!mesh_attr) return mlir::success();
75
76 auto mesh_from_block_arg_or_status =
77 Mesh::FromString(mesh_attr.getValue().str());
78 if (!mesh_from_block_arg_or_status.ok()) {
79 return func_op.emitOpError(
80 "Failed during mesh propagation. Op operand has invalid serialized "
81 "mesh");
82 }
83
84 out->emplace(mesh_from_block_arg_or_status.ValueOrDie());
85 return mlir::success();
86 }
87
88 // Extracts mesh of operation that produces `value`.
ExtractMeshFromOpOutput(mlir::Value value,absl::optional<Mesh> * out)89 mlir::LogicalResult ExtractMeshFromOpOutput(mlir::Value value,
90 absl::optional<Mesh>* out) {
91 auto input_op = value.getDefiningOp();
92 if (!input_op) return mlir::success();
93
94 auto operand_cluster =
95 llvm::dyn_cast<mlir::tf_device::ClusterOp>(value.getDefiningOp());
96 if (!operand_cluster) {
97 return mlir::emitError(value.getLoc())
98 << "operand must be from different device cluster.";
99 }
100
101 auto mesh_or_status = ExtractDeviceMeshFromOp(operand_cluster);
102 if (!mesh_or_status.ok())
103 return operand_cluster.emitOpError(
104 llvm::formatv("Failed during mesh propagation. {0}",
105 mesh_or_status.status().error_message()));
106
107 auto extracted_mesh = mesh_or_status.ValueOrDie();
108 if (extracted_mesh) *out = extracted_mesh.value();
109 return mlir::success();
110 }
111
112 // Extracts mesh configuration from `operand`. If operand is a function
113 // argument, then mesh config is extracted from "tf._mesh" arg attribute of the
114 // corresponding func op. If operand is from a preceding op, then mesh
115 // configuration is extracted from the enclosing tf_device.Cluster op.
ExtractMeshFromOperand(const llvm::DenseMap<mlir::OpOperand *,std::vector<mlir::Value>> & producers,mlir::OpOperand * operand,absl::optional<Mesh> * out)116 mlir::LogicalResult ExtractMeshFromOperand(
117 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
118 mlir::OpOperand* operand, absl::optional<Mesh>* out) {
119 mlir::Value operand_value = operand->get();
120
121 const auto check_and_assign_mesh =
122 [](mlir::Location loc, absl::optional<Mesh>& mesh,
123 absl::optional<Mesh>& operand_mesh) -> mlir::LogicalResult {
124 if (mesh && !operand_mesh) {
125 operand_mesh.swap(mesh);
126 } else if (mesh && operand_mesh && mesh != operand_mesh) {
127 return mlir::emitError(
128 loc,
129 "Error during mesh propagation. Found inconsistent mesh "
130 "while inferring mesh from operands.");
131 }
132 return mlir::success();
133 };
134
135 // If `operand` is a block argument then extract mesh from `tf._mesh`
136 // attribute of the corresponding function argument.
137 if (auto block_arg = operand_value.dyn_cast<mlir::BlockArgument>()) {
138 if (mlir::failed(ExtractMeshFromBlockArgument(block_arg, out)))
139 return mlir::failure();
140
141 if (!out->has_value()) {
142 auto it = producers.find(operand);
143 if (it != producers.end()) {
144 auto producer_values = it->getSecond();
145 absl::optional<Mesh> operand_mesh;
146 for (mlir::Value producer_value : producer_values) {
147 if (auto arg = producer_value.dyn_cast<mlir::BlockArgument>()) {
148 absl::optional<Mesh> mesh;
149 if (mlir::failed(ExtractMeshFromBlockArgument(arg, &mesh)))
150 return mlir::failure();
151
152 if (mlir::failed(check_and_assign_mesh(
153 operand->getOwner()->getLoc(), mesh, operand_mesh)))
154 return mlir::failure();
155 } else {
156 auto input_cluster =
157 producer_value.getDefiningOp()
158 ->getParentOfType<mlir::tf_device::ClusterOp>();
159 auto output_from_producing_op = input_cluster.getResult(
160 producer_value.cast<mlir::OpResult>().getResultNumber());
161
162 absl::optional<Mesh> mesh;
163 if (mlir::failed(
164 ExtractMeshFromOpOutput(output_from_producing_op, &mesh)))
165 return mlir::failure();
166
167 if (mlir::failed(check_and_assign_mesh(
168 operand->getOwner()->getLoc(), mesh, operand_mesh)))
169 return mlir::failure();
170 }
171 }
172 *out = operand_mesh;
173 }
174 }
175 return mlir::success();
176 }
177
178 // If `operand` is from another operation, extract mesh from enclosing
179 // tf_device.cluster op of the input operation.
180 if (mlir::failed(ExtractMeshFromOpOutput(operand_value, out)))
181 return mlir::failure();
182
183 return mlir::success();
184 }
185
// Infers mesh of `cluster` from its operands. If mesh can be inferred, all
// operands must have the same mesh.
mlir::LogicalResult InferMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* inputs_with_inferred_mesh) {
  auto result = mlir::success();

  // If `cluster` wraps a `tf.CopyToMesh` op, do not infer mesh from its
  // inputs. `tf.CopyToMesh` specifies that all operations following the
  // operation are executed on the target device mesh cluster specified by
  // `tf.CopyToMesh`.
  if (llvm::isa<mlir::TF::CopyToMeshOp>(&cluster.GetBody().front()))
    return result;

  mlir::visitUsedValuesDefinedAbove(
      cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
        // Stop inspecting further operands once an error has been recorded.
        if (mlir::failed(result)) return;
        absl::optional<Mesh> extracted_config;

        // If the input to the cluster is from a DTensorLayout op, then use the
        // mesh extracted from the DTensorLayout op to infer the mesh of the
        // cluster.
        if (auto layout_op =
                llvm::dyn_cast<mlir::TF::DTensorLayout>(operand->getOwner())) {
          auto mesh = layout_op.layout().mesh();
          extracted_config.emplace(mesh);
        } else {
          auto extract_result =
              ExtractMeshFromOperand(producers, operand, &extracted_config);
          if (mlir::failed(extract_result)) {
            result = extract_result;
            return;
          }
        }

        // DTensorDevice may create a graph with resource arguments with an
        // empty layout. These layouts of the resource values will be added
        // after layout is inferred from resource update ops. Therefore, ignore
        // DTensorLayout ops with empty layouts.
        if (!extracted_config || extracted_config->IsEmpty()) return;

        inputs_with_inferred_mesh->emplace_back(operand);
        // All mesh-carrying inputs must agree on a single mesh.
        if (mesh->has_value() && extracted_config != mesh->value()) {
          result = cluster.emitOpError(
              "failed during mesh propagation. All inputs to "
              "`tf_device.Cluster` must have same mesh configuration.");
        }

        if (!mesh->has_value()) mesh->emplace(extracted_config.value());
      });

  return result;
}
239
// Extracts mesh from function return attributes. If the `tf._default_layout`
// attribute exists, mesh from the default layout is used. If not, mesh from
// the `tf._mesh` attribute is used.
// NOTE: "Fuction" in the name is a long-standing typo; renaming would require
// touching all call sites, so it is preserved here.
StatusOr<absl::optional<Mesh>> ExtractMeshFromFuctionOutput(
    const int output_index, mlir::func::FuncOp function) {
  absl::optional<Mesh> function_mesh;
  // The function terminator is where per-result layout attrs are resolved.
  auto terminator = llvm::cast<mlir::func::ReturnOp>(
      function.getBody().front().getTerminator());
  TF_ASSIGN_OR_RETURN(auto layout, ExtractLayoutFromFunctionReturnAttr(
                                       terminator, output_index));

  if (layout) {
    function_mesh.emplace(layout->mesh());
    return function_mesh;
  }

  // No default layout on this result; fall back to the serialized `tf._mesh`
  // result attribute.
  auto output_mesh_attr = function.getResultAttrOfType<mlir::StringAttr>(
      output_index, kCustomDeviceMeshAttr);
  if (output_mesh_attr) {
    TF_ASSIGN_OR_RETURN(auto mesh,
                        Mesh::FromString(output_mesh_attr.getValue().str()));
    function_mesh.emplace(std::move(mesh));
  }
  return function_mesh;
}
265
// Infers mesh from users of `cluster` and records the usages that were used to
// infer mesh configuration in `consumers_with_mesh`.
mlir::LogicalResult InferMeshFromConsumers(
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* consumers_with_mesh) {
  for (auto& use_value : cluster.getOperation()->getUses()) {
    mlir::Operation* consumer = use_value.getOwner();

    // `tf.CopyToMesh` specifies that all operations following the
    // operation are executed on the target device mesh cluster specified by
    // `tf.CopyToMesh`. Therefore, if the `consumer` operation is
    // `tf.CopyToMesh`, do not propagate mesh backwards to `cluster`.
    if (llvm::isa<mlir::TF::CopyToMeshOp>(consumer)) continue;

    Mesh extracted_mesh;

    // If the `cluster` output is an output value of a function, then infer
    // mesh using the function return value attribute, if it exists.
    if (auto return_op = llvm::dyn_cast<mlir::func::ReturnOp>(consumer)) {
      auto status_or_mesh = ExtractMeshFromFuctionOutput(
          use_value.getOperandNumber(),
          return_op->getParentOfType<mlir::func::FuncOp>());
      if (!status_or_mesh.ok())
        return cluster.emitOpError(status_or_mesh.status().ToString());

      auto mesh = status_or_mesh.ValueOrDie();
      if (mesh) extracted_mesh = *mesh;
    } else {
      // If the `cluster` output is input to another cluster/op then infer mesh
      // from the consumer operation.
      auto consumer_cluster =
          consumer->getParentOfType<mlir::tf_device::ClusterOp>();
      if (!consumer_cluster) {
        return cluster.emitOpError(
            "failed to propagate mesh information. All operations must be "
            "enclosed inside a tf_device.cluster op.");
      }

      auto mesh_or_status = ExtractDeviceMeshFromOp(consumer_cluster);
      if (!mesh_or_status.ok())
        return cluster.emitOpError(mesh_or_status.status().error_message());

      auto consumer_mesh = mesh_or_status.ValueOrDie();
      // A consumer cluster without a mesh contributes no information.
      if (!consumer_mesh) continue;

      extracted_mesh = consumer_mesh.value();
    }

    if (extracted_mesh.IsEmpty()) continue;

    // All mesh-carrying consumers must agree on a single mesh.
    if (mesh->has_value() && extracted_mesh != mesh->value()) {
      return cluster.emitOpError(
          "failed to propagate mesh information. Mesh for op is ambiguous as "
          "consumers have different mesh attributes");
    }

    consumers_with_mesh->emplace_back(&use_value);
    if (!mesh->has_value()) mesh->emplace(std::move(extracted_mesh));
  }
  return mlir::success();
}
327
// Infers the default mesh of `function` given its inputs and outputs. The
// function has a default mesh only if all of its inputs/outputs have values
// assigned to the same mesh; on any disagreement `inferred_default_mesh` is
// reset and success is returned (no default mesh, not an error).
mlir::LogicalResult InferFunctionDefaultMesh(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder,
    absl::optional<mlir::StringAttr>* inferred_default_mesh) {
  // First pass: collect meshes of all returned values from the clusters that
  // produce them.
  auto terminator = function.getCallableRegion()->front().getTerminator();
  for (auto& result_value : terminator->getOpOperands()) {
    auto result_defining_op = result_value.get().getDefiningOp();
    if (!result_defining_op) continue;

    auto result_cluster =
        llvm::cast<mlir::tf_device::ClusterOp>(result_defining_op);
    auto result_mesh =
        result_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
    if (!result_mesh) continue;

    // Two outputs on different meshes -> no unambiguous default mesh.
    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value() != result_mesh) {
      inferred_default_mesh->reset();
      return mlir::success();
    }
    inferred_default_mesh->emplace(result_mesh);
  }

  // Second pass: fold in meshes inferred from the function arguments.
  absl::optional<Mesh> inferred_mesh_from_args;
  for (auto function_arg : function.getArguments()) {
    auto uses = function_arg.getUses();
    if (uses.empty()) {
      // Unused argument: the mesh can only come from its own attributes.
      if (mlir::failed(ExtractMeshFromBlockArgument(function_arg,
                                                    &inferred_mesh_from_args)))
        return mlir::failure();
    } else {
      // Only the first use is inspected. NOTE(review): this assumes all uses
      // of an argument resolve to the same mesh — confirm against callers.
      auto operand = uses.begin().getOperand();
      if (mlir::failed(ExtractMeshFromOperand(producers, operand,
                                              &inferred_mesh_from_args)))
        return mlir::failure();
    }
    if (!inferred_mesh_from_args) continue;

    std::string mesh_str = inferred_mesh_from_args->ToString();
    // An argument mesh that disagrees with the meshes seen so far -> no
    // default mesh.
    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value().getValue().str() != mesh_str) {
      inferred_default_mesh->reset();
      return mlir::success();
    }

    inferred_default_mesh->emplace(builder->getStringAttr(std::move(mesh_str)));
  }
  return mlir::success();
}
379
380 // Annotates `tf._mesh` attribute to argument of `function` with
381 // string of `mesh`.
AnnotateFunctionArgumentsWithMeshInformation(const Mesh & mesh,const llvm::SmallVector<mlir::OpOperand *,8> & input_values_from_mesh,mlir::func::FuncOp function,mlir::OpBuilder * builder)382 void AnnotateFunctionArgumentsWithMeshInformation(
383 const Mesh& mesh,
384 const llvm::SmallVector<mlir::OpOperand*, 8>& input_values_from_mesh,
385 mlir::func::FuncOp function, mlir::OpBuilder* builder) {
386 for (auto value : input_values_from_mesh) {
387 function.setArgAttr(value->getOperandNumber(), kCustomDeviceMeshAttr,
388 builder->getStringAttr(mesh.ToString()));
389 }
390 }
391
// Annotates return value attributes of `function_to_annotate` with mesh
// information parsed from usages of the function. `callsite_operation` is the
// callable op whose function definition is `function_to_annotate`.
mlir::LogicalResult AnnotateFunctionReturnValuesWithMeshInformation(
    const llvm::SmallVector<mlir::OpOperand*, 8>& return_values_from_mesh,
    mlir::Operation* callsite_operation,
    mlir::func::FuncOp function_to_annotate, mlir::OpBuilder* builder) {
  for (auto value : return_values_from_mesh) {
    absl::optional<mlir::StringAttr> result_mesh_attribute;
    if (llvm::isa<mlir::func::ReturnOp>(value->getOwner())) {
      // The consuming use is the parent function's return: read the mesh from
      // the parent function's result attributes, preferring the default
      // layout attribute over the plain mesh attribute.
      auto parent_function =
          callsite_operation->getParentOfType<mlir::func::FuncOp>();
      auto function_result_layout =
          parent_function.getResultAttrOfType<mlir::StringAttr>(
              value->getOperandNumber(), kCustomDefaultLayoutAttr);
      if (function_result_layout) {
        auto layout_or_status =
            Layout::FromString(function_result_layout.getValue().str());
        if (!layout_or_status.ok())
          return parent_function.emitOpError(
              layout_or_status.status().error_message());

        result_mesh_attribute.emplace(
            builder->getStringAttr(layout_or_status->mesh().ToString()));
      } else {
        auto function_result_mesh =
            parent_function.getResultAttrOfType<mlir::StringAttr>(
                value->getOperandNumber(), kCustomDeviceMeshAttr);
        if (function_result_mesh)
          result_mesh_attribute.emplace(function_result_mesh);
      }
    } else {
      // Otherwise the consumer is a regular op; use its `_mesh` attribute.
      auto op_mesh =
          value->getOwner()->getAttrOfType<mlir::StringAttr>(kMeshAttr);
      if (op_mesh) result_mesh_attribute.emplace(std::move(op_mesh));
    }

    // Record the discovered mesh on the corresponding result attribute of the
    // function being annotated, so inference inside the callee can use it.
    if (result_mesh_attribute)
      function_to_annotate.setResultAttr(
          value->get().cast<mlir::OpResult>().getResultNumber(),
          kCustomDeviceMeshAttr, result_mesh_attribute.value());
  }
  return mlir::success();
}
436
// MLIR pass that propagates mesh information to tf_device.Cluster ops.
struct DTensorMeshPropagation
    : public DTensorMeshPropagationBase<DTensorMeshPropagation> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);
    auto module = getOperation();
    // Propagation starts from the `main` entry-point function; modules
    // without one are left untouched.
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    mlir::Dialect* tf_dialect =
        context.getLoadedDialect<mlir::TF::TensorFlowDialect>();

    // This maps from OpResults to a list of OpOperands that consume this.
    // Note that this will pass over/through
    // (Stateful)PartitionedCall and other control flow, directly connecting
    // producing ops to their consumers in the function. I.e. it presents a
    // flattened/inlined view of the flow of data.
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
    // Maintain a reverse mapping. Note that for control-flow operations like
    // the tf.If op, there may be multiple producers for a mlir::Value.
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers;

    // Create consumers and producers maps.
    if (mlir::failed(
            PopulateConsumersFromModule(&module, tf_dialect, consumers)))
      return signalPassFailure();

    // Invert the consumers map to build the producers map.
    for (auto& consumer : consumers) {
      for (auto* operand : consumer.second) {
        producers[operand].emplace_back(consumer.first);
      }
    }

    // Iterate to a fixed point: each round may assign meshes that enable
    // further inference on the next round.
    bool mesh_changed = true;
    while (mesh_changed) {
      mesh_changed = false;
      if (mlir::failed(
              PropagateMesh(producers, main_func, &builder, &mesh_changed)))
        return signalPassFailure();
    }
  }

  // Propagates and sets `_mesh` attributes to all clusters inside `function`
  // if possible.
  mlir::LogicalResult PropagateMesh(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);

  // Infers mesh of `cluster` from its input operations.
  mlir::LogicalResult PropagateMeshFromInputs(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Infers mesh of `cluster` from its consuming operations.
  mlir::LogicalResult PropagateMeshFromConsumers(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Assigns the function default mesh to clusters with no mesh specified.
  // Note that a function has a default mesh only if all its dtensor
  // inputs/outputs have values assigned to a single mesh.
  mlir::LogicalResult PropagateDefaultMeshToUnAssignedClusters(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);
};
510
511 mlir::LogicalResult
PropagateDefaultMeshToUnAssignedClusters(const llvm::DenseMap<mlir::OpOperand *,std::vector<mlir::Value>> & producers,mlir::func::FuncOp function,mlir::OpBuilder * builder,bool * mesh_changed)512 DTensorMeshPropagation::PropagateDefaultMeshToUnAssignedClusters(
513 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
514 mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
515 absl::optional<mlir::StringAttr> mesh;
516 if (mlir::failed(
517 InferFunctionDefaultMesh(producers, function, builder, &mesh)))
518 return mlir::failure();
519
520 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters_without_mesh;
521 auto walk_result = function.walk([&](mlir::tf_device::ClusterOp cluster) {
522 auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
523 if (!mesh_or_status.ok()) {
524 cluster.GetBody().front().emitOpError(
525 mesh_or_status.status().error_message());
526 return mlir::WalkResult::interrupt();
527 }
528
529 const auto& mesh = mesh_or_status.ValueOrDie();
530 if (mesh.has_value()) return mlir::WalkResult::advance();
531
532 clusters_without_mesh.emplace_back(cluster);
533 return mlir::WalkResult::advance();
534 });
535
536 if (walk_result.wasInterrupted()) return mlir::failure();
537
538 if (!mesh.has_value()) return mlir::success();
539
540 // Set function default mesh to cluster with unspecified mesh.
541 for (auto cluster_without_mesh : clusters_without_mesh) {
542 *mesh_changed = true;
543 cluster_without_mesh->setAttr(kMeshAttr, mesh.value());
544 }
545
546 return mlir::success();
547 }
548
// Sets mesh of `cluster` if possible, inferring it from the inputs of
// `cluster`. For function call sites, also propagates mesh information into
// the callee's definition recursively.
mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  // If the operation inside a mesh cluster is not a callable operation and
  // mesh is already specified on the cluster, do nothing.
  auto inner_func = MaybeFindFunction(&cluster.GetBody().front());
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  if (!inner_func && cluster_mesh) return mlir::success();

  // If mesh of `cluster` is not specified, infer mesh using inputs of the
  // mesh cluster.
  absl::optional<Mesh> extracted_mesh;
  llvm::SmallVector<mlir::OpOperand*, 8> inputs_with_inferred_mesh;
  if (failed(InferMeshFromInputs(producers, cluster, &extracted_mesh,
                                 &inputs_with_inferred_mesh))) {
    return mlir::failure();
  }

  // If the operation inside `cluster` is a function call, annotate input and
  // output mesh of `cluster` using function argument and return value
  // attributes, then recursively propagate mesh of the function definition.
  if (inner_func) {
    // All inputs to cluster must be from the same mesh. If the input mesh to
    // the callable operation is inferred, then annotate the input mesh on the
    // function argument attribute so that this information can be used to
    // infer mesh of ops inside `inner_func`.
    if (extracted_mesh.has_value()) {
      AnnotateFunctionArgumentsWithMeshInformation(extracted_mesh.value(),
                                                   inputs_with_inferred_mesh,
                                                   inner_func.value(), builder);
    }

    // Recursively propagate mesh to clusters in the function definition of
    // `inner_func`.
    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    // Once all clusters inside the `inner_func` callable have been set, now
    // we can infer mesh of `cluster`. That is, the mesh of the call site
    // operation is equal to the mesh of the return values of the function.
    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (!cluster_mesh && extracted_mesh.has_value()) {
    // Non-callsite cluster: assign the mesh inferred from inputs directly.
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr,
                     builder->getStringAttr(extracted_mesh->ToString()));
  }
  return mlir::success();
}
607
// Sets mesh of `cluster`, inferring mesh from consumer operations of
// `cluster`. For function call sites, also propagates mesh information into
// the callee's definition recursively.
mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromConsumers(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  mlir::Operation* op_inside_cluster = &cluster.GetBody().front();
  auto inner_func = MaybeFindFunction(op_inside_cluster);
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  // If mesh is already set and the cluster is not a function call site, then
  // do nothing.
  if (!inner_func && cluster_mesh) return mlir::success();

  // Infer mesh of `cluster` from its output usages.
  absl::optional<Mesh> extracted_mesh_from_consumers;
  llvm::SmallVector<mlir::OpOperand*, 8> consumers_with_mesh_information;
  if (failed(InferMeshFromConsumers(cluster, &extracted_mesh_from_consumers,
                                    &consumers_with_mesh_information)))
    return mlir::failure();

  // If the operation inside the mesh cluster is a function callsite operation,
  // then propagate mesh of the function recursively.
  if (inner_func) {
    // Record the inferred consumer meshes on the callee's result attributes
    // so that inference inside the callee can use them.
    if (mlir::failed(AnnotateFunctionReturnValuesWithMeshInformation(
            consumers_with_mesh_information, op_inside_cluster,
            inner_func.value(), builder)))
      return mlir::failure();

    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    // Once the callee's clusters are assigned, the call site's mesh equals
    // the callee's default mesh.
    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (extracted_mesh_from_consumers && !cluster_mesh) {
    // Non-callsite cluster: assign the mesh inferred from consumers directly.
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr, builder->getStringAttr(
                                    extracted_mesh_from_consumers->ToString()));
  }
  return mlir::success();
}
654
655 // Propagates mesh information to all `tf_device.Cluster` ops in `function`. If
656 // `function` includes callable ops, then recursively traverse the function
657 // definition to propagate mesh information using input operands and consuming
658 // result ops. Note that at current stage of graph optimization,
659 // tf_device.cluster ops are enclosing a single operation.
PropagateMesh(const llvm::DenseMap<mlir::OpOperand *,std::vector<mlir::Value>> & producers,mlir::func::FuncOp function,mlir::OpBuilder * builder,bool * mesh_changed)660 mlir::LogicalResult DTensorMeshPropagation::PropagateMesh(
661 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
662 mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
663 // Iterate clusters in topological order propagating mesh from operations'
664 // inputs.
665 llvm::SmallVector<mlir::tf_device::ClusterOp, 8> cluster_ops;
666 for (auto cluster : function.getOps<mlir::tf_device::ClusterOp>()) {
667 cluster_ops.emplace_back(cluster);
668
669 if (mlir::failed(
670 PropagateMeshFromInputs(producers, cluster, builder, mesh_changed)))
671 return mlir::failure();
672 }
673
674 // Iterate clusters in reverse topological order and propagate mesh from
675 // consumers.
676 for (auto cluster : llvm::reverse(cluster_ops)) {
677 if (mlir::failed(PropagateMeshFromConsumers(producers, cluster, builder,
678 mesh_changed)))
679 return mlir::failure();
680 }
681
682 if (mlir::failed(PropagateDefaultMeshToUnAssignedClusters(
683 producers, function, builder, mesh_changed)))
684 return mlir::failure();
685
686 return mlir::success();
687 }
688
689 } // namespace
690
// Creates an instance of the DTensor mesh propagation pass.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorMeshPropagationPass() {
  return std::make_unique<DTensorMeshPropagation>();
}
695
696 } // namespace dtensor
697 } // namespace tensorflow
698