/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "absl/types/optional.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Extracts mesh of `block_arg` by parsing the function argument attributes of
// its enclosing function. The mesh is inferred from either the `tf._layout` or
// the `tf._mesh` attribute.
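//
// For illustration only (not an exact IR snippet), a block argument annotated
// as
//   %arg0: tensor<4xf32> {tf._mesh = "<serialized mesh string>"}
// would yield the mesh deserialized from that string.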
mlir::LogicalResult ExtractMeshFromBlockArgument(mlir::BlockArgument block_arg,
                                                 absl::optional<Mesh>* out) {
  auto func_op = mlir::dyn_cast_or_null<mlir::func::FuncOp>(
      block_arg.getOwner()->getParentOp());
  if (!func_op) {
    return block_arg.getOwner()->getParentOp()->emitOpError(
        "must be enclosed by a function");
  }
  auto layout_or_status = ExtractLayoutFromOperand(block_arg);
  if (!layout_or_status.ok())
    return func_op.emitOpError(layout_or_status.status().error_message());

  if (layout_or_status->has_value()) {
    out->emplace(layout_or_status->value().mesh());
    return mlir::success();
  }

  auto mesh_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
      block_arg.getArgNumber(), kCustomDeviceMeshAttr);
  if (!mesh_attr) return mlir::success();

  auto mesh_from_block_arg_or_status =
      Mesh::FromString(mesh_attr.getValue().str());
  if (!mesh_from_block_arg_or_status.ok()) {
    return func_op.emitOpError(
        "Failed during mesh propagation. Op operand has invalid serialized "
        "mesh");
  }

  out->emplace(mesh_from_block_arg_or_status.ValueOrDie());
  return mlir::success();
}

// Extracts mesh of the operation that produces `value`.
mlir::LogicalResult ExtractMeshFromOpOutput(mlir::Value value,
                                            absl::optional<Mesh>* out) {
  auto input_op = value.getDefiningOp();
  if (!input_op) return mlir::success();

  auto operand_cluster =
      llvm::dyn_cast<mlir::tf_device::ClusterOp>(value.getDefiningOp());
  if (!operand_cluster) {
    return mlir::emitError(value.getLoc())
           << "operand must be from different device cluster.";
  }

  auto mesh_or_status = ExtractDeviceMeshFromOp(operand_cluster);
  if (!mesh_or_status.ok())
    return operand_cluster.emitOpError(
        llvm::formatv("Failed during mesh propagation. {0}",
                      mesh_or_status.status().error_message()));

  auto extracted_mesh = mesh_or_status.ValueOrDie();
  if (extracted_mesh) *out = extracted_mesh.value();
  return mlir::success();
}

// Extracts mesh configuration from `operand`. If the operand is a function
// argument, then the mesh config is extracted from the "tf._mesh" arg
// attribute of the corresponding func op. If the operand is from a preceding
// op, then the mesh configuration is extracted from the enclosing
// tf_device.Cluster op.
mlir::LogicalResult ExtractMeshFromOperand(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::OpOperand* operand, absl::optional<Mesh>* out) {
  mlir::Value operand_value = operand->get();

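  // Helper that folds a newly extracted `mesh` into `operand_mesh`: takes the
  // mesh if `operand_mesh` is still unset, and emits an error if two different
  // meshes are encountered.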
  const auto check_and_assign_mesh =
      [](mlir::Location loc, absl::optional<Mesh>& mesh,
         absl::optional<Mesh>& operand_mesh) -> mlir::LogicalResult {
    if (mesh && !operand_mesh) {
      operand_mesh.swap(mesh);
    } else if (mesh && operand_mesh && mesh != operand_mesh) {
      return mlir::emitError(
          loc,
          "Error during mesh propagation. Found inconsistent mesh "
          "while inferring mesh from operands.");
    }
    return mlir::success();
  };

  // If `operand` is a block argument then extract mesh from `tf._mesh`
  // attribute of the corresponding function argument.
  if (auto block_arg = operand_value.dyn_cast<mlir::BlockArgument>()) {
    if (mlir::failed(ExtractMeshFromBlockArgument(block_arg, out)))
      return mlir::failure();

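    // If the block argument itself carries no mesh, fall back to the values
    // that may flow into this operand according to `producers` (e.g. through
    // control flow) and require that they all agree on a single mesh.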
    if (!out->has_value()) {
      auto it = producers.find(operand);
      if (it != producers.end()) {
        auto producer_values = it->getSecond();
        absl::optional<Mesh> operand_mesh;
        for (mlir::Value producer_value : producer_values) {
          if (auto arg = producer_value.dyn_cast<mlir::BlockArgument>()) {
            absl::optional<Mesh> mesh;
            if (mlir::failed(ExtractMeshFromBlockArgument(arg, &mesh)))
              return mlir::failure();

            if (mlir::failed(check_and_assign_mesh(
                    operand->getOwner()->getLoc(), mesh, operand_mesh)))
              return mlir::failure();
          } else {
            auto input_cluster =
                producer_value.getDefiningOp()
                    ->getParentOfType<mlir::tf_device::ClusterOp>();
            auto output_from_producing_op = input_cluster.getResult(
                producer_value.cast<mlir::OpResult>().getResultNumber());

            absl::optional<Mesh> mesh;
            if (mlir::failed(
                    ExtractMeshFromOpOutput(output_from_producing_op, &mesh)))
              return mlir::failure();

            if (mlir::failed(check_and_assign_mesh(
                    operand->getOwner()->getLoc(), mesh, operand_mesh)))
              return mlir::failure();
          }
        }
        *out = operand_mesh;
      }
    }
    return mlir::success();
  }

  // If `operand` is from another operation, extract mesh from enclosing
  // tf_device.cluster op of the input operation.
  if (mlir::failed(ExtractMeshFromOpOutput(operand_value, out)))
    return mlir::failure();

  return mlir::success();
}

// Infers mesh of `cluster` from its operands. If a mesh can be inferred, all
// operands must have the same mesh.
mlir::LogicalResult InferMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* inputs_with_inferred_mesh) {
  auto result = mlir::success();

  // If `cluster` wraps a `tf.CopyToMesh` op, do not infer mesh from its
  // inputs. `tf.CopyToMesh` specifies that all operations following it are
  // executed on the target device mesh cluster specified by `tf.CopyToMesh`.
  if (llvm::isa<mlir::TF::CopyToMeshOp>(&cluster.GetBody().front()))
    return result;

  mlir::visitUsedValuesDefinedAbove(
      cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
        if (mlir::failed(result)) return;
        absl::optional<Mesh> extracted_config;

        // If the input is consumed by a DTensorLayout op, then use the mesh
        // extracted from that DTensorLayout op to infer the mesh of the
        // cluster.
        if (auto layout_op =
                llvm::dyn_cast<mlir::TF::DTensorLayout>(operand->getOwner())) {
          auto mesh = layout_op.layout().mesh();
          extracted_config.emplace(mesh);
        } else {
          auto extract_result =
              ExtractMeshFromOperand(producers, operand, &extracted_config);
          if (mlir::failed(extract_result)) {
            result = extract_result;
            return;
          }
        }

        // DTensorDevice may create a graph with resource arguments that have
        // an empty layout. The layouts of these resource values will be added
        // after layout is inferred from resource update ops. Therefore, ignore
        // DTensorLayout ops with empty layouts.
        if (!extracted_config || extracted_config->IsEmpty()) return;

        inputs_with_inferred_mesh->emplace_back(operand);
        if (mesh->has_value() && extracted_config != mesh->value()) {
          result = cluster.emitOpError(
              "failed during mesh propagation. All inputs to "
              "`tf_device.Cluster` must have same mesh configuration.");
        }

        if (!mesh->has_value()) mesh->emplace(extracted_config.value());
      });

  return result;
}

// Extracts mesh from function return attributes. If the `tf._default_layout`
// attribute exists, the mesh from the default layout is used. If not, the mesh
// from the `tf._mesh` attribute is used.
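//
// For illustration only, a function result annotated with
//   {tf._default_layout = "<serialized layout>"} or {tf._mesh = "<serialized mesh>"}
// yields the mesh of that layout or the deserialized mesh, respectively.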
StatusOr<absl::optional<Mesh>> ExtractMeshFromFuctionOutput(
    const int output_index, mlir::func::FuncOp function) {
  absl::optional<Mesh> function_mesh;
  auto terminator = llvm::cast<mlir::func::ReturnOp>(
      function.getBody().front().getTerminator());
  TF_ASSIGN_OR_RETURN(auto layout, ExtractLayoutFromFunctionReturnAttr(
                                       terminator, output_index));

  if (layout) {
    function_mesh.emplace(layout->mesh());
    return function_mesh;
  }

  auto output_mesh_attr = function.getResultAttrOfType<mlir::StringAttr>(
      output_index, kCustomDeviceMeshAttr);
  if (output_mesh_attr) {
    TF_ASSIGN_OR_RETURN(auto mesh,
                        Mesh::FromString(output_mesh_attr.getValue().str()));
    function_mesh.emplace(std::move(mesh));
  }
  return function_mesh;
}

// Infers mesh from users of `cluster` and records the usages that were used to
// infer mesh configuration in `consumers_with_mesh`.
mlir::LogicalResult InferMeshFromConsumers(
    mlir::tf_device::ClusterOp cluster, absl::optional<Mesh>* mesh,
    llvm::SmallVector<mlir::OpOperand*, 8>* consumers_with_mesh) {
  for (auto& use_value : cluster.getOperation()->getUses()) {
    mlir::Operation* consumer = use_value.getOwner();

    // `tf.CopyToMesh` specifies that all operations following it are executed
    // on the target device mesh cluster specified by `tf.CopyToMesh`.
    // Therefore, if the `consumer` operation is `tf.CopyToMesh`, do not
    // propagate mesh backwards to `cluster`.
    if (llvm::isa<mlir::TF::CopyToMeshOp>(consumer)) continue;

    Mesh extracted_mesh;

    // If the `cluster` output is a return value of a function, then infer mesh
    // using the function return value attribute, if it exists.
    if (auto return_op = llvm::dyn_cast<mlir::func::ReturnOp>(consumer)) {
      auto status_or_mesh = ExtractMeshFromFuctionOutput(
          use_value.getOperandNumber(),
          return_op->getParentOfType<mlir::func::FuncOp>());
      if (!status_or_mesh.ok())
        return cluster.emitOpError(status_or_mesh.status().ToString());

      auto mesh = status_or_mesh.ValueOrDie();
      if (mesh) extracted_mesh = *mesh;
    } else {
      // If the `cluster` output is an input to another cluster/op, then infer
      // mesh from the consumer operation.
      auto consumer_cluster =
          consumer->getParentOfType<mlir::tf_device::ClusterOp>();
      if (!consumer_cluster) {
        return cluster.emitOpError(
            "failed to propagate mesh information. All operations must be "
            "enclosed inside a tf_device.cluster op.");
      }

      auto mesh_or_status = ExtractDeviceMeshFromOp(consumer_cluster);
      if (!mesh_or_status.ok())
        return cluster.emitOpError(mesh_or_status.status().error_message());

      auto consumer_mesh = mesh_or_status.ValueOrDie();
      if (!consumer_mesh) continue;

      extracted_mesh = consumer_mesh.value();
    }

    if (extracted_mesh.IsEmpty()) continue;

    if (mesh->has_value() && extracted_mesh != mesh->value()) {
      return cluster.emitOpError(
          "failed to propagate mesh information. Mesh for op is ambiguous as "
          "consumers have different mesh attributes");
    }

    consumers_with_mesh->emplace_back(&use_value);
    if (!mesh->has_value()) mesh->emplace(std::move(extracted_mesh));
  }
  return mlir::success();
}

// Infers the default mesh of a function given its inputs and outputs. A
// function has a default mesh if all of its inputs/outputs have values
// assigned to the same mesh.
mlir::LogicalResult InferFunctionDefaultMesh(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder,
    absl::optional<mlir::StringAttr>* inferred_default_mesh) {
  auto terminator = function.getCallableRegion()->front().getTerminator();
  for (auto& result_value : terminator->getOpOperands()) {
    auto result_defining_op = result_value.get().getDefiningOp();
    if (!result_defining_op) continue;

    auto result_cluster =
        llvm::cast<mlir::tf_device::ClusterOp>(result_defining_op);
    auto result_mesh =
        result_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
    if (!result_mesh) continue;

    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value() != result_mesh) {
      inferred_default_mesh->reset();
      return mlir::success();
    }
    inferred_default_mesh->emplace(result_mesh);
  }

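  // Check that the mesh inferred from the function arguments agrees with the
  // mesh inferred from the return values above; on any mismatch there is no
  // single default mesh, so reset and return.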
  absl::optional<Mesh> inferred_mesh_from_args;
  for (auto function_arg : function.getArguments()) {
    auto uses = function_arg.getUses();
    if (uses.empty()) {
      if (mlir::failed(ExtractMeshFromBlockArgument(function_arg,
                                                    &inferred_mesh_from_args)))
        return mlir::failure();
    } else {
      auto operand = uses.begin().getOperand();
      if (mlir::failed(ExtractMeshFromOperand(producers, operand,
                                              &inferred_mesh_from_args)))
        return mlir::failure();
    }
    if (!inferred_mesh_from_args) continue;

    std::string mesh_str = inferred_mesh_from_args->ToString();
    if (inferred_default_mesh->has_value() &&
        inferred_default_mesh->value().getValue().str() != mesh_str) {
      inferred_default_mesh->reset();
      return mlir::success();
    }

    inferred_default_mesh->emplace(builder->getStringAttr(std::move(mesh_str)));
  }
  return mlir::success();
}

// Annotates arguments of `function` with a `tf._mesh` attribute containing the
// string representation of `mesh`.
void AnnotateFunctionArgumentsWithMeshInformation(
    const Mesh& mesh,
    const llvm::SmallVector<mlir::OpOperand*, 8>& input_values_from_mesh,
    mlir::func::FuncOp function, mlir::OpBuilder* builder) {
  for (auto value : input_values_from_mesh) {
    function.setArgAttr(value->getOperandNumber(), kCustomDeviceMeshAttr,
                        builder->getStringAttr(mesh.ToString()));
  }
}

// Annotates return value attributes of `function_to_annotate` with mesh
// information parsed from usages of the function. `callsite_operation` is the
// callable op whose function definition is `function_to_annotate`.
mlir::LogicalResult AnnotateFunctionReturnValuesWithMeshInformation(
    const llvm::SmallVector<mlir::OpOperand*, 8>& return_values_from_mesh,
    mlir::Operation* callsite_operation,
    mlir::func::FuncOp function_to_annotate, mlir::OpBuilder* builder) {
  for (auto value : return_values_from_mesh) {
    absl::optional<mlir::StringAttr> result_mesh_attribute;
    if (llvm::isa<mlir::func::ReturnOp>(value->getOwner())) {
      auto parent_function =
          callsite_operation->getParentOfType<mlir::func::FuncOp>();
      auto function_result_layout =
          parent_function.getResultAttrOfType<mlir::StringAttr>(
              value->getOperandNumber(), kCustomDefaultLayoutAttr);
      if (function_result_layout) {
        auto layout_or_status =
            Layout::FromString(function_result_layout.getValue().str());
        if (!layout_or_status.ok())
          return parent_function.emitOpError(
              layout_or_status.status().error_message());

        result_mesh_attribute.emplace(
            builder->getStringAttr(layout_or_status->mesh().ToString()));
      } else {
        auto function_result_mesh =
            parent_function.getResultAttrOfType<mlir::StringAttr>(
                value->getOperandNumber(), kCustomDeviceMeshAttr);
        if (function_result_mesh)
          result_mesh_attribute.emplace(function_result_mesh);
      }
    } else {
      auto op_mesh =
          value->getOwner()->getAttrOfType<mlir::StringAttr>(kMeshAttr);
      if (op_mesh) result_mesh_attribute.emplace(std::move(op_mesh));
    }

    if (result_mesh_attribute)
      function_to_annotate.setResultAttr(
          value->get().cast<mlir::OpResult>().getResultNumber(),
          kCustomDeviceMeshAttr, result_mesh_attribute.value());
  }
  return mlir::success();
}

// MLIR pass that propagates mesh information to tf_device.Cluster ops.
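//
// Schematically (illustrative pseudo-IR only, details elided): given two
// clusters where the second uses a value produced by the first,
//
//   %0 = "tf_device.cluster"() ({...}) {_mesh = "<serialized mesh>"} : () -> ...
//   %1 = "tf_device.cluster"() ({... uses %0 ...}) : () -> ...
//
// the pass infers the mesh of the second cluster from that input and sets the
// same mesh attribute on it.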
struct DTensorMeshPropagation
    : public DTensorMeshPropagationBase<DTensorMeshPropagation> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);
    auto module = getOperation();
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    mlir::Dialect* tf_dialect =
        context.getLoadedDialect<mlir::TF::TensorFlowDialect>();

    // Maps from OpResults to the list of OpOperands that consume them. Note
    // that this passes over/through (Stateful)PartitionedCall and other
    // control flow, directly connecting producing ops to their consumers in
    // the function, i.e. it presents a flattened/inlined view of the flow of
    // data.
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
    // Maintain the reverse mapping. Note that for control flow operations such
    // as the tf.If op, there may be multiple producers for an mlir::Value.
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers;

    // Create the consumers and producers maps.
    if (mlir::failed(
            PopulateConsumersFromModule(&module, tf_dialect, consumers)))
      return signalPassFailure();

    for (auto& consumer : consumers) {
      for (auto* operand : consumer.second) {
        producers[operand].emplace_back(consumer.first);
      }
    }

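    // Iterate propagation to a fixed point: each round may assign meshes that
    // enable further inference in the next round; stop once no cluster's mesh
    // changes.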
    bool mesh_changed = true;
    while (mesh_changed) {
      mesh_changed = false;
      if (mlir::failed(
              PropagateMesh(producers, main_func, &builder, &mesh_changed)))
        return signalPassFailure();
    }
  }

  // Propagates and sets `_mesh` attributes to all clusters inside `function`
  // if possible.
  mlir::LogicalResult PropagateMesh(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);

  // Infers mesh of `cluster` from its input operations.
  mlir::LogicalResult PropagateMeshFromInputs(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Infers mesh of `cluster` from its consuming operations.
  mlir::LogicalResult PropagateMeshFromConsumers(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
      bool* mesh_changed);

  // Assigns the function default mesh to clusters with no mesh specified. Note
  // that a function has a default mesh if all of its dtensor inputs/outputs
  // are assigned to a single mesh.
  mlir::LogicalResult PropagateDefaultMeshToUnAssignedClusters(
      const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>&
          producers,
      mlir::func::FuncOp, mlir::OpBuilder* builder, bool* mesh_changed);
};

mlir::LogicalResult
DTensorMeshPropagation::PropagateDefaultMeshToUnAssignedClusters(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
  absl::optional<mlir::StringAttr> mesh;
  if (mlir::failed(
          InferFunctionDefaultMesh(producers, function, builder, &mesh)))
    return mlir::failure();

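  // Collect all clusters that do not have a mesh assigned yet.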
  llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters_without_mesh;
  auto walk_result = function.walk([&](mlir::tf_device::ClusterOp cluster) {
    auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
    if (!mesh_or_status.ok()) {
      cluster.GetBody().front().emitOpError(
          mesh_or_status.status().error_message());
      return mlir::WalkResult::interrupt();
    }

    const auto& mesh = mesh_or_status.ValueOrDie();
    if (mesh.has_value()) return mlir::WalkResult::advance();

    clusters_without_mesh.emplace_back(cluster);
    return mlir::WalkResult::advance();
  });

  if (walk_result.wasInterrupted()) return mlir::failure();

  if (!mesh.has_value()) return mlir::success();

  // Set the function default mesh on clusters with an unspecified mesh.
  for (auto cluster_without_mesh : clusters_without_mesh) {
    *mesh_changed = true;
    cluster_without_mesh->setAttr(kMeshAttr, mesh.value());
  }

  return mlir::success();
}

mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromInputs(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  // If the operation inside the mesh cluster is not a callable operation and
  // the mesh is already specified on the cluster, do nothing.
  auto inner_func = MaybeFindFunction(&cluster.GetBody().front());
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  if (!inner_func && cluster_mesh) return mlir::success();

  // If the mesh of `cluster` is not specified, infer the mesh using the inputs
  // of the mesh cluster.
  absl::optional<Mesh> extracted_mesh;
  llvm::SmallVector<mlir::OpOperand*, 8> inputs_with_inferred_mesh;
  if (failed(InferMeshFromInputs(producers, cluster, &extracted_mesh,
                                 &inputs_with_inferred_mesh))) {
    return mlir::failure();
  }

  // If the operation inside `cluster` is a function call, annotate the input
  // and output mesh of `cluster` using function argument and return value
  // attributes, then recursively propagate mesh through the function
  // definition.
  if (inner_func) {
    // All inputs to the cluster must be from the same mesh. If the input mesh
    // to the callable operation is inferred, annotate the input mesh on the
    // function argument attributes so that this information can be used to
    // infer the mesh of ops inside `inner_func`.
    if (extracted_mesh.has_value()) {
      AnnotateFunctionArgumentsWithMeshInformation(extracted_mesh.value(),
                                                   inputs_with_inferred_mesh,
                                                   inner_func.value(), builder);
    }

    // Recursively propagate mesh to clusters in the function definition of
    // `inner_func`.
    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    // Once the mesh for all clusters inside the `inner_func` callable has been
    // set, we can infer the mesh of `cluster`. That is, the mesh of the call
    // site operation equals the mesh of the return values of the function.
    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (!cluster_mesh && extracted_mesh.has_value()) {
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr,
                     builder->getStringAttr(extracted_mesh->ToString()));
  }
  return mlir::success();
}

// Sets mesh of `cluster`, inferring mesh from consumer operations of `cluster`.
mlir::LogicalResult DTensorMeshPropagation::PropagateMeshFromConsumers(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::tf_device::ClusterOp cluster, mlir::OpBuilder* builder,
    bool* mesh_changed) {
  mlir::Operation* op_inside_cluster = &cluster.GetBody().front();
  auto inner_func = MaybeFindFunction(op_inside_cluster);
  auto cluster_mesh = cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  // If the mesh is already set, then do nothing.
  if (!inner_func && cluster_mesh) return mlir::success();

  // Infer mesh of `cluster` from its output usages.
  absl::optional<Mesh> extracted_mesh_from_consumers;
  llvm::SmallVector<mlir::OpOperand*, 8> consumers_with_mesh_information;
  if (failed(InferMeshFromConsumers(cluster, &extracted_mesh_from_consumers,
                                    &consumers_with_mesh_information)))
    return mlir::failure();

  // If the operation inside the mesh cluster is a function callsite operation,
  // then propagate the mesh of the function recursively.
  if (inner_func) {
    if (mlir::failed(AnnotateFunctionReturnValuesWithMeshInformation(
            consumers_with_mesh_information, op_inside_cluster,
            inner_func.value(), builder)))
      return mlir::failure();

    if (mlir::failed(PropagateMesh(producers, inner_func.value(), builder,
                                   mesh_changed)))
      return mlir::failure();

    absl::optional<mlir::StringAttr> function_mesh;
    if (mlir::failed(InferFunctionDefaultMesh(producers, inner_func.value(),
                                              builder, &function_mesh)))
      return mlir::failure();

    if (function_mesh && !cluster_mesh) {
      *mesh_changed = true;
      cluster->setAttr(kMeshAttr, function_mesh.value());
    }
  } else if (extracted_mesh_from_consumers && !cluster_mesh) {
    *mesh_changed = true;
    cluster->setAttr(kMeshAttr, builder->getStringAttr(
                                    extracted_mesh_from_consumers->ToString()));
  }
  return mlir::success();
}

// Propagates mesh information to all `tf_device.Cluster` ops in `function`. If
// `function` includes callable ops, then recursively traverses the function
// definition to propagate mesh information using input operands and consuming
// result ops. Note that at this stage of graph optimization, tf_device.cluster
// ops enclose a single operation.
mlir::LogicalResult DTensorMeshPropagation::PropagateMesh(
    const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    mlir::func::FuncOp function, mlir::OpBuilder* builder, bool* mesh_changed) {
  // Iterate over clusters in topological order, propagating mesh from
  // operations' inputs.
  llvm::SmallVector<mlir::tf_device::ClusterOp, 8> cluster_ops;
  for (auto cluster : function.getOps<mlir::tf_device::ClusterOp>()) {
    cluster_ops.emplace_back(cluster);

    if (mlir::failed(
            PropagateMeshFromInputs(producers, cluster, builder, mesh_changed)))
      return mlir::failure();
  }

  // Iterate over clusters in reverse topological order and propagate mesh from
  // consumers.
  for (auto cluster : llvm::reverse(cluster_ops)) {
    if (mlir::failed(PropagateMeshFromConsumers(producers, cluster, builder,
                                                mesh_changed)))
      return mlir::failure();
  }

  if (mlir::failed(PropagateDefaultMeshToUnAssignedClusters(
          producers, function, builder, mesh_changed)))
    return mlir::failure();

  return mlir::success();
}

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorMeshPropagationPass() {
  return std::make_unique<DTensorMeshPropagation>();
}

}  // namespace dtensor
}  // namespace tensorflow