xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/layout_propagation_v2.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <iterator>
18 #include <queue>
19 #include <string>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/types/optional.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
36 #include "mlir/IR/Operation.h"  // from @llvm-project
37 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
38 #include "mlir/IR/Types.h"  // from @llvm-project
39 #include "mlir/IR/Value.h"  // from @llvm-project
40 #include "mlir/IR/Visitors.h"  // from @llvm-project
41 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
46 #include "tensorflow/compiler/mlir/utils/name_utils.h"
47 #include "tensorflow/dtensor/cc/constants.h"
48 #include "tensorflow/dtensor/cc/dtensor_utils.h"
49 #include "tensorflow/dtensor/cc/tensor_layout.h"
50 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
51 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
52 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
53 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
54 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
55 #include "tensorflow/dtensor/mlir/layout_parsing.h"
56 #include "tensorflow/dtensor/mlir/op_utils.h"
57 #include "tensorflow/dtensor/mlir/spmd_expander.h"
58 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
59 #include "tensorflow/dtensor/mlir/value_utils.h"
60 
61 namespace tensorflow {
62 namespace dtensor {
63 
64 // This value dictates how many times during layout propagation we allow
65 // fixing of oscillatory behaviors.
66 constexpr int kLayoutPropagationMaxStages = 3;
67 
68 bool AllOpResultsHaveLayouts(
69     mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
70     const llvm::DenseMap<mlir::Value, Layout>& layouts) {
71   const auto& result = module->walk([&](mlir::Operation* op) {
72     if (op->getDialect() != tf_dialect ||
73         mlir::isa<mlir::TF::DTensorLayout>(op))
74       return mlir::WalkResult::advance();
75     for (const auto& result : op->getOpResults()) {
76       if (layouts.find(result) == layouts.end()) {
77         op->emitOpError() << "missing layout for result "
78                           << result.getResultNumber();
79         return mlir::WalkResult::interrupt();
80       }
81     }
82     return mlir::WalkResult::advance();
83   });
84   return !result.wasInterrupted();
85 }
86 
87 void UpdateLayoutForSkippedOps(
88     mlir::OpOperand& operand,
89     const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
90     const Layout& layout_to_copy,
91     llvm::DenseMap<mlir::Value, Layout>& layouts) {
92   llvm::SmallVector<mlir::Value, 4> skipped_values;
93   TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values);
94   for (const mlir::Value& skipped_value : skipped_values)
95     if ((!skipped_value.isa<mlir::OpResult>() ||
96          !mlir::isa<mlir::TF::DTensorLayout, mlir::tf_device::ClusterOp>(
97              skipped_value.getDefiningOp())) &&
98         layouts.find(skipped_value) == layouts.end())
99       // TraceUseToNextTFOp's contract is that it only skips over ops that
100       // act like the identity (such as function calls, returns, yields,
101       // controlflow, DTensorLayouts, etc). This means that operand layout
102       // that we came from is the layout we want for this value.
103       layouts[skipped_value] = layout_to_copy;
104 }
105 
106 // Some ops, which are skipped by TraceUseToNextTFOp, will not have layouts
107 // for their mlir::OpResults.
108 // E.g. during the creation of the consumers map, we skip the input and output
109 // of the WhileRegion op. In particular if we have:
110 //
111 // %b = tf.WhileRegion(%a) ({
112 //     %bb0(%arg0):  # Cond
113 //       %c = tf.A(%arg0)
114 //       tf.Yield(%c)
115 //     }, {
116 //     %bb0(%arg0):  # Body
117 //       %d = tf.B(%arg0)
118 //       tf.Yield(%d)
119 //     }
120 //   }
121 // %e = tf.C(%b)
122 //
123 // Then the consumers map would directly connect the mlir::Value %a to input 0
124 // of tf.A and tf.B, bypassing the WhileRegion and the mlir::Value of %arg0.
125 // Similarly it would connect the mlir::Value of %d directly to input 0 of
126 // tf.C bypassing the mlir::Value of %b.
127 // This means that at the end of layout propagation the skipped values would not
128 // have an assigned layout. But this layout can be derived by taking the known
129 // layout of %a and propagating to each mlir::Value that was skipped while
130 // connecting %a to the input 0 of tf.A and tf.B. Similarly we derive the layout
131 // for %b from %d.
132 //
133 // To get these layouts, we:
134 // 1) Iterate over all ops that have layouts for their OpResults and call
135 //    TraceUseToNextTFOp to get the skipped mlir::Values.
136 // 2) If any skipped mlir::Value doesn't have a layout set, then we set the
137 //    layout.
138 mlir::LogicalResult CopyLayoutsForSkippedOps(
139     mlir::ModuleOp module, mlir::Dialect* tf_dialect,
140     llvm::DenseMap<mlir::Value, Layout>& layouts) {
141   llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;
142 
143   if (mlir::failed(GetFuncToCaller(module, func_to_caller)))
144     return mlir::failure();
145 
146   // Update layouts derived from ops.
147   module->walk([&](mlir::Operation* op) {
148     for (mlir::OpOperand& operand : op->getOpOperands()) {
149       if (layouts.find(operand.get()) == layouts.end()) continue;
150       const Layout layout = layouts[operand.get()];
151       UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts);
152     }
153   });
154 
155   // Update layouts derived from inputs
156   mlir::func::FuncOp main_func =
157       module.lookupSymbol<mlir::func::FuncOp>("main");
158   if (!main_func) return mlir::success();
159 
160   for (auto& value : main_func.getArguments()) {
161     if (layouts.find(value) == layouts.end()) continue;
162     const Layout layout = layouts[value];
163 
164     for (mlir::OpOperand& operand : value.getUses())
165       UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts);
166   }
167 
168   return mlir::success();
169 }
170 
171 namespace {
172 void FilterkAnySpecs(std::vector<std::string>& proposed_specs) {
173   for (auto& spec : proposed_specs) {
174     if (spec == Layout::kAny) spec = Layout::kUnshardedDim;
175   }
176 }
177 }  // namespace
178 
179 // Merges the producer and consumer layouts into a single layout.
180 // Assumes that all layouts are of the same rank.
181 // Consumers are first merged together so that we have the layout which is
182 // sharded in a tensor dim if and only if all consumers are sharded in the same
183 // sharding_spec.
184 // If producer layout is present, we merge the consumer layouts into the layout
185 // of the producer: if the consumer wants a sharded layout in a tensor dimension
186 // where the producer is unsharded *and* the mesh dimension it wants to be
187 // sharded over is not already sharded over by the producer, then we add that
188 // sharding to the producer layout.
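// For example (illustrative): consumers {("x", *), ("x", *)} merge to ("x", *),
// while {("x", *), (*, "y")} merge to (*, *) because the consumers disagree in
// each tensor dimension. With a producer of ("x", *) and a merged consumer
// layout of (*, "x"), the result is ("x", *): dim 0 takes the producer's
// sharding, and dim 1 stays unsharded because mesh dim "x" is already used by
// the producer.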
189 StatusOr<Layout> MergeLayouts(
190     const absl::optional<Layout>& producer,
191     const mlir::DenseMap<mlir::OpOperand*, Layout>& consumers) {
192   if (consumers.empty()) return producer.value();
193 
194   // Initialize the specs to those of the first consumer layout and merge
195   // consumers.
196   std::vector<std::string> proposed_specs =
197       consumers.begin()->second.sharding_spec_strs();
198   int layout_rank = proposed_specs.size();
199 
200   // Verify consumer layout ranks match.
201   for (const auto& consumer : consumers) {
202     const Layout& consumer_layout = consumer.second;
203     if (consumer_layout.rank() != layout_rank)
204       return errors::InvalidArgument(
205           "found two consumer layout of different ranks: ",
206           consumer_layout.rank(), " and ", layout_rank);
207   }
208 
209   // Merge consumer layouts.
210   for (const auto& consumer : consumers) {
211     const Layout& consumer_layout = consumer.second;
212 
213     // Check every tensor dimension.
214     for (int j = 0; j < consumer_layout.rank(); ++j) {
215       const std::string& consumer_spec_j = consumer_layout.sharding_spec(j);
216       if (consumer_spec_j == Layout::kAny) continue;
217 
218       // If the proposed spec is set as any, give priority to the consumer spec.
219       if (proposed_specs[j] == Layout::kAny) {
220         proposed_specs[j] = consumer_spec_j;
221         continue;
222       }
223 
224       // If any consumer layout disagrees with the current merge, set the
225       // spec to not sharded.
226       if (proposed_specs[j] != consumer_spec_j)
227         proposed_specs[j] = Layout::kUnshardedDim;
228     }
229   }
230 
231   // Filter over-sharded specs.
232   absl::flat_hash_map<std::string, int> counter;
233   for (const std::string& spec : proposed_specs) counter[spec] += 1;
234   for (std::string& spec : proposed_specs)
235     if (counter[spec] > 1) spec = Layout::kUnshardedDim;
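  // e.g. a proposed layout of ("x", "x") becomes (*, *) here, since a mesh
  // dimension may shard at most one tensor dimension.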
236 
237   // Return layout if there is no producer, else move into producer algorithm.
238   const Mesh mesh = consumers.begin()->second.mesh();
239   if (!producer) {
240     FilterkAnySpecs(proposed_specs);
241     return Layout::GetLayout(proposed_specs, mesh);
242   }
243 
244   if (producer->rank() != layout_rank) {
245     return errors::InvalidArgument(
246         "producer and consumer layout have different ranks: ", producer->rank(),
247         " and ", layout_rank);
248   }
249 
250   // For the producer merge, first we define mesh dims used by the producer to
251   // avoid creating a layout that shards twice over the same mesh dim.
252   llvm::DenseSet<llvm::StringRef> producer_dims;
253   for (int j = 0; j < producer->rank(); ++j) {
254     llvm::StringRef spec = producer->sharding_spec(j);
255     if (Layout::IsShardedDimension(spec.str())) producer_dims.insert(spec);
256   }
257   // Merge producer layout with existing layout.
258   for (int j = 0; j < producer->rank(); ++j) {
259     const std::string& producer_spec = producer->sharding_spec(j);
260 
261     if (producer_spec == proposed_specs[j] || producer_spec == Layout::kAny)
262       continue;
263 
264     if (proposed_specs[j] == Layout::kAny) {
265       proposed_specs[j] = producer_spec;
266       continue;
267     }
268     // If the producer is unsharded and proposed_spec is sharded, make sure the
269     // mesh dim is not already used elsewhere by the producer; if so, unshard it.
270     if (Layout::IsUnshardedDimension(producer_spec)) {
271       bool isMeshDimUsed = producer_dims.contains(proposed_specs[j]);
272       if (isMeshDimUsed) {
273         proposed_specs[j] = Layout::kUnshardedDim;
274       }
275     } else {
276       // If producer is sharded we can set layout to shard over same
277       // mesh dim.
278       //
279       // If mesh dim is already used in the layout elsewhere it will
280       // get unset by the case above.
281       proposed_specs[j] = producer_spec;
282     }
283   }
284   FilterkAnySpecs(proposed_specs);
285   return Layout::GetLayout(proposed_specs, mesh);
286 }
287 
288 mlir::LogicalResult InsertLayoutsForDTensorLayout(
289     mlir::ModuleOp& module,
290     llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
291     llvm::DenseSet<mlir::Value>& is_updated,
292     llvm::DenseSet<mlir::Value>& is_locked) {
293   return mlir::failure(
294       module
295           .walk([&](mlir::TF::DTensorLayout op) -> mlir::WalkResult {
296             // Check there are no "Layout::kAny" or "kMatch" specs in the
297             // layouts.
298             for (const std::string& spec : op.layout().sharding_spec_strs())
299               if (spec == Layout::kAny || spec == Layout::kMatch)
300                 return op->emitOpError()
301                        << "found " << spec
302                        << " as a sharding spec which is not allowed";
303             // Insert layout.
304             producer_request[op.input()].emplace(op.layout());
305             is_updated.insert(op.input());
306             is_locked.insert(op.input());
307             return mlir::WalkResult::advance();
308           })
309           .wasInterrupted());
310 }
311 
312 // Runs the ComputeLayout API on all ops in the graph **without** any
313 // consumer-requested layouts or operand layouts populated.
314 mlir::LogicalResult InsertInitialLayoutsFromComputeLayout(
315     mlir::ModuleOp module,
316     const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
317     const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
318     llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
319     llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
320         consumer_requests,
321     llvm::DenseSet<mlir::Value>& is_updated) {
322   auto walk_result = module.walk([&](mlir::Operation* op) {
323     // We ignore ops that don't have either an OpResult in consumers or an
324     // OpOperand in producers. Note that if one operand is missing from
325     // producers then all operands should be missing from producers, as should
326     // all op results from consumers, and vice versa.
327 
328     if (op->getNumOperands() > 0) {
329       if (producers.find(&op->getOpOperand(0)) == producers.end())
330         return mlir::WalkResult::advance();
331     } else if (op->getNumResults() > 0) {
332       if (consumers.find(op->getOpResult(0)) == consumers.end())
333         return mlir::WalkResult::advance();
334     } else {
335       // Note that this case should never happen (i.e. a TF op should have
336       // either inputs or outputs, but that isn't technically guaranteed).
337       return mlir::WalkResult::advance();
338     }
339 
340     auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
341     if (expander == nullptr) {
342       op->emitOpError() << "does not implement layout propagation";
343       return mlir::WalkResult::interrupt();
344     }
345 
346     // Invoke ComputeLayout on `op` with empty input/consumer layouts.
347     StatusOr<llvm::DenseMap<int, Layout>> forward_result =
348         expander->ComputeLayoutForward(
349             op, /*input_layouts=*/llvm::DenseMap<int, Layout>(),
350             /*output_layouts=*/llvm::DenseMap<int, Layout>());
351     if (!forward_result.ok()) {
352       op->emitOpError() << "ComputeLayoutForward error: "
353                         << forward_result.status().error_message();
354       return mlir::WalkResult::interrupt();
355     }
356     StatusOr<llvm::DenseMap<int, Layout>> backward_result =
357         expander->ComputeLayoutBackward(
358             op, /*input_layouts=*/llvm::DenseMap<int, Layout>(),
359             /*output_layouts=*/llvm::DenseMap<int, Layout>());
360     if (!backward_result.ok()) {
361       op->emitOpError() << "ComputeLayoutBackward error: "
362                         << backward_result.status().error_message();
363       return mlir::WalkResult::interrupt();
364     }
365 
366     // If any operand layouts were returned, add the layout to consumer requests
367     // and set the value as updated.
368     for (auto const& op_idx_and_op_layout : *backward_result) {
369       auto const& op_idx = op_idx_and_op_layout.first;
370       auto const& op_layout = op_idx_and_op_layout.second;
371       auto& operand = op->getOpOperand(op_idx);
372       const auto& producer_values = producers.lookup(&operand);
373       for (mlir::Value producer_value : producer_values) {
374         if (!consumer_requests[producer_value].count(&operand))
375           consumer_requests[producer_value][&operand] = op_layout;
376 
377         is_updated.insert(producer_value);
378       }
379     }
380 
381     // If any output layouts were returned, add the layout to producer requests
382     // and set the value as updated.
383     for (auto const& out_idx_and_out_layout : *forward_result) {
384       auto const& out_idx = out_idx_and_out_layout.first;
385       auto const& out_layout = out_idx_and_out_layout.second;
386       mlir::Value output_value = op->getResult(out_idx);
387       producer_request.try_emplace(output_value, out_layout);
388       is_updated.insert(output_value);
389     }
390 
391     return mlir::WalkResult::advance();
392   });
393   return mlir::failure(walk_result.wasInterrupted());
394 }
395 
396 // Propagates mesh and inserts initial layouts for
397 // * Any DTensorLayout ops (this handles function inputs and other ops with
398 //   user layouts).
399 // * CopyToMesh
400 // * ConstOp
401 mlir::LogicalResult InsertInitialLayouts(
402     mlir::ModuleOp& module, mlir::func::FuncOp& main_func,
403     const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
404     const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
405     llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
406         consumer_request,
407     llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
408     llvm::DenseSet<mlir::Value>& is_updated,
409     llvm::DenseSet<mlir::Value>& is_locked) {
410   std::queue<mlir::Operation*> operations;
411 
412   if (mlir::failed(InsertLayoutsForDTensorLayout(module, producer_request,
413                                                  is_updated, is_locked)))
414     return mlir::failure();
415   return InsertInitialLayoutsFromComputeLayout(module, consumers, producers,
416                                                producer_request,
417                                                consumer_request, is_updated);
418 }
419 
420 // Given a list of mlir::Values with updated producer or consumer layouts
421 // update the merged_layouts list and track which layouts actually changed.
422 mlir::LogicalResult MergeAndGetUpdatedLayouts(
423     const llvm::DenseSet<mlir::Value>& is_locked,
424     llvm::DenseSet<mlir::Value>& is_updated,
425     llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
426     llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
427         consumer_requests,
428     llvm::DenseMap<mlir::Value, Layout>& merged_layouts) {
429   llvm::DenseSet<mlir::Value> updated_merge;
430   for (auto& value : is_updated) {
431     auto& producer_layout = producer_request[value];
432     if (is_locked.find(value) != is_locked.end()) {
433       // Locked values must have a producer request. If the merged_layout is
434       // not already set, then this is the first pass, so we set it and mark
435       // the entry as updated.
436       if (merged_layouts.find(value) == merged_layouts.end()) {
437         if (!producer_layout)
438           return value.getDefiningOp()->emitError() << "missing locked layout";
439         merged_layouts[value] = producer_layout.value();
440         updated_merge.insert(value);
441       }
442       continue;
443     }
444     auto merged = MergeLayouts(producer_layout, consumer_requests[value]);
445     if (!merged.ok())
446       return value.getDefiningOp()->emitOpError()
447              << merged.status().error_message();
448 
449     auto current_layout = merged_layouts.find(value);
450     if (current_layout == merged_layouts.end() ||
451         current_layout->second != merged.ValueOrDie()) {
452       updated_merge.insert(value);
453       merged_layouts[value] = merged.ValueOrDie();
454     }
455   }
456 
457   is_updated = updated_merge;
458   return mlir::success();
459 }
460 
461 // Finds the most sharded merged layout given `layouts`.
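// For example (illustrative): merging ("x", *) and (*, "y") gives ("x", "y"),
// while merging ("x", *) and (*, "x") gives (*, *), since mesh dim "x" would
// otherwise shard two different tensor dimensions.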
462 mlir::LogicalResult GetMostShardedLayout(llvm::ArrayRef<Layout> layouts,
463                                          mlir::Location location,
464                                          absl::optional<Layout>* out) {
465   // If there are no layouts to merge, leave the output empty.
466   if (layouts.empty()) return mlir::success();
467 
468   absl::optional<Layout> layout;
469   std::map<std::string, std::set<int>> layout_map;
470   for (const Layout& layout : layouts) {
471     for (int i = 0; i < layout.rank(); ++i) {
472       const std::string& mesh_dim = layout.dim(i).sharding_spec();
473       if (mesh_dim == Layout::kUnshardedDim) continue;
474 
475       layout_map[mesh_dim].insert(i);
476     }
477   }
478 
479   for (auto& it : layout_map)
480     if (it.second.size() > 1) it.second.clear();
481 
482   std::map<int, std::set<std::string>> dim_to_layout_map;
483   for (const auto& it : layout_map) {
484     assert(it.second.size() <= 1);
485     if (it.second.empty()) continue;
486 
487     const int tensor_dim_index = *it.second.begin();
488     dim_to_layout_map[tensor_dim_index].insert(it.first);
489   }
490 
491   for (auto& it : dim_to_layout_map)
492     if (it.second.size() > 1) it.second.clear();
493 
494   std::vector<std::string> merged_spec;
495   assert(!layouts.empty());
496   for (int i = 0; i < layouts[0].rank(); ++i) {
497     const auto it = dim_to_layout_map.find(i);
498     if (it != dim_to_layout_map.end() && !it->second.empty()) {
499       assert(it->second.size() == 1);
500       merged_spec.emplace_back(*it->second.begin());
501     } else {
502       merged_spec.emplace_back(Layout::kUnshardedDim);
503     }
504   }
505   const auto new_layout = Layout::GetLayout(merged_spec, layouts[0].mesh());
506   if (!new_layout.ok()) {
507     return mlir::emitError(
508         location, llvm::formatv("error in layout propagation while merging "
509                                 "producer layouts. {0}",
510                                 new_layout.status().error_message()));
511   }
512   out->emplace(*new_layout);
513   return mlir::success();
514 }
515 
516 // Merge layouts of mlir::Value from multiple producers into a single final
517 // layout. A mlir::Value can have multiple producers if the value is from a
518 // tf.If/tf.IfRegion op. Given multiple producer layouts of the same
519 // mlir::Value, the merging logic is as follows:
520 //   1) If a dimension can be sharded, shard the dimension as much as possible.
521 //   2) If a mesh dimension is already used, or the same mesh dimension is used
522 //      in two different tensor dimensions, leave the tensor dimension replicated.
523 //
524 // For example:
525 //  ("x", replicated) , (replicated, "y") will have ("x", "y") merged layout.
526 //  ("x", replicated) , (replicated, "x") will have (replicated, replicated)
527 // merged layout.
528 mlir::LogicalResult MergeProducerLayouts(
529     const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
530     const std::vector<mlir::Value>& producer_values, mlir::Location location,
531     absl::optional<Layout>* layout_out) {
532   // If there is a single producer for mlir::Value, then return the layout
533   // from the producer.
534   absl::optional<Layout> layout;
535   if (producer_values.size() == 1) {
536     const auto it = merged_layouts.find(producer_values[0]);
537     if (it != merged_layouts.end()) *layout_out = it->second;
538     return mlir::success();
539   }
540 
541   // For the case with multiple producers, merge the layouts.
542   llvm::SmallVector<Layout, 4> candidate_layouts;
543   candidate_layouts.reserve(producer_values.size());
544   for (mlir::Value value : producer_values) {
545     auto it = merged_layouts.find(value);
546     if (it == merged_layouts.end()) continue;
547     candidate_layouts.emplace_back(it->second);
548   }
549 
550   if (mlir::failed(GetMostShardedLayout(candidate_layouts, location, &layout)))
551     return mlir::failure();
552 
553   if (layout) *layout_out = *layout;
554   return mlir::success();
555 }
556 
557 // For an op, calls the corresponding ComputeLayouts function with the data from
558 // the merged_layouts map. Records the result in the producer_request and
559 // consumer_requests maps and notes if any layouts have changed.
560 mlir::LogicalResult UpdateLayoutsForOp(
561     mlir::Operation* op,
562     const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
563     const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
564     llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
565     llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
566         consumer_requests,
567     llvm::DenseSet<mlir::Value>& is_updated) {
568   auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
569   if (expander == nullptr)
570     return op->emitOpError() << "does not implement layout propagation";
571 
572   // Get input and output layouts for this op from the merged_layouts map.
573   llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());
574   llvm::DenseMap<int, Layout> output_layouts(op->getNumResults());
575 
576   for (int i = 0; i < op->getNumOperands(); ++i) {
577     // For inputs, we need to find the producer's mlir::Value that eventually
578     // feeds into this op. This is in the producers map.
579     // Merge the different layouts of the multiple producer values.
580     auto producer_values = producers.find(&(op->getOpOperand(i)));
581     if (producer_values == producers.end())
582       return op->emitError() << "Unable to find producer for operand " << i;
583 
584     absl::optional<Layout> layout;
585     if (mlir::failed(MergeProducerLayouts(merged_layouts,
586                                           producer_values->getSecond(),
587                                           op->getLoc(), &layout)))
588       return mlir::failure();
589 
590     if (layout) input_layouts[i] = *layout;
591   }
592 
593   for (int i = 0; i < op->getNumResults(); ++i) {
594     auto layout = merged_layouts.find(op->getOpResult(i));
595     if (layout != merged_layouts.end()) output_layouts[i] = layout->second;
596   }
597 
598   auto forward_result =
599       expander->ComputeLayoutForward(op, input_layouts, output_layouts);
600   if (!forward_result.ok()) {
601     return op->emitOpError() << "ComputeLayoutForward error: "
602                              << forward_result.status().error_message();
603   }
604   const auto new_output_layouts = *forward_result;
605   auto backward_result =
606       expander->ComputeLayoutBackward(op, input_layouts, output_layouts);
607   if (!backward_result.ok()) {
608     return op->emitOpError() << "ComputeLayoutBackward error: "
609                              << backward_result.status().error_message();
610   }
611   const auto new_input_layouts = *backward_result;
612 
613   // Update the consumer layouts for this op.
614   for (int i = 0; i < op->getNumOperands(); ++i) {
615     mlir::OpOperand* operand = &(op->getOpOperand(i));
616     // No need to check that this exists, we already did it above.
617     const auto& producer_values = producers.find(operand);
618     const auto input_layout = new_input_layouts.find(i);
619 
620     for (mlir::Value value : producer_values->getSecond()) {
621       auto& consumer_request = consumer_requests[value];
622       const auto consumer_request_from_op_operand =
623           consumer_request.find(operand);
624 
625       // Update the consumer_request for this OpOperand: we respect what compute
626       // layout returns and erase a previously requested layout if no layout is
627       // returned.
628       // TODO(hongjunchoi, bfontain): Consider the case when op output type is
629       // resource type with subtype information.
630       if (input_layout != new_input_layouts.end() &&
631           (consumer_request_from_op_operand == consumer_request.end() ||
632            input_layout->second != consumer_request_from_op_operand->second)) {
633         // RestoreV2 op most likely would have unknown rank upon restoring, and
634         // we relax the unknown-rank check for the inputs that are produced from
635         // there.
636         const bool exempt_restore_unknown_rank =
637             ValueRank(value) == -1 && value.getDefiningOp() &&
638             llvm::isa<mlir::TF::RestoreV2Op>(value.getDefiningOp());
639         if (!exempt_restore_unknown_rank &&
640             input_layout->second.rank() != ValueRank(value))
641           return op->emitOpError()
642                  << "Rank for input " << i << " layout is "
643                  << input_layout->second.rank() << " but actual rank is "
644                  << ValueRank(value);
645 
646         // If there was a layout returned and either no previous request or the
647         // request changed, insert and mark as updated.
648         consumer_request[operand] = input_layout->second;
649         is_updated.insert(value);
650       } else if (input_layout == new_input_layouts.end() &&
651                  consumer_request_from_op_operand != consumer_request.end()) {
652         // If no layout was returned and there is previous request, erase the
653         // old consumer request.
654         consumer_request.erase(operand);
655         is_updated.insert(value);
656       }
657     }
658   }
659 
660   // Update the producer request for this op.
661   // If the output is different from what is in the request list, update
662   // the request and mark the mlir::Value as having an updated Layout request.
663   for (int i = 0; i < op->getNumResults(); ++i) {
664     const auto output_layout = new_output_layouts.find(i);
665     if (output_layout == new_output_layouts.end()) continue;
666     const auto& result = op->getOpResult(i);
667     if (producer_request[result] != output_layout->second) {
668       if (output_layout->second.rank() != ValueRank(result))
669         return op->emitOpError() << "Rank for output " << i << " layout is "
670                                  << output_layout->second.rank()
671                                  << " but actual rank is " << ValueRank(result);
672       producer_request[result] = output_layout->second;
673       is_updated.insert(result);
674     }
675   }
676   return mlir::success();
677 }
678 
679 mlir::LogicalResult InsertDTensorLayoutOps(
680     mlir::OpBuilder& builder,
681     const llvm::DenseMap<mlir::Value, Layout>& merged_layouts) {
682   for (const auto& merged_layout : merged_layouts) {
683     // merged_layout is a pair of mlir::Value and Layout.
684     // If there is only one user of the Value and that user is a DTensorLayout
685     // op, then we can skip creating the op as the layout is already there. Note
686     // that we specifically do not allow updating a layout in an already present
687     // DTensorLayout op as we have considered them to be 'locked' throughout
688     // the algorithm.
689     const auto& users = merged_layout.first.getUsers();
690     int num_users = std::distance(users.begin(), users.end());
691     if (num_users == 1 && mlir::isa<mlir::TF::DTensorLayout>(*users.begin()))
692       continue;
693     builder.setInsertionPointAfterValue(merged_layout.first);
694     // Handles resource and variant types, as the real shape is embedded in
695     // the type's subtype elements.
696     mlir::Type value_type = GetSubtypeOrSelf(merged_layout.first);
697 
698     if (auto type = value_type.dyn_cast<mlir::TensorType>()) {
699       auto layout_op = builder.create<mlir::TF::DTensorLayout>(
700           merged_layout.first.getLoc(), merged_layout.first,
701           mlir::dtensor::LayoutAttr::get(builder.getContext(),
702                                          merged_layout.second),
703           mlir::TF::ShapeAttr::get(builder.getContext(), type));
704       llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
705       merged_layout.first.replaceAllUsesExcept(layout_op.output(), exception);
706     } else {
707       mlir::emitError(merged_layout.first.getLoc())
708           << "value type is not TensorType as expected.";
709     }
710   }
711 
712   return mlir::success();
713 }
714 
715 void GetOperationsNeedingUpdate(
716     const llvm::DenseSet<mlir::Value>& is_updated,
717     const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
718     llvm::DenseSet<mlir::Operation*>& operations) {
719   for (auto& value : is_updated) {
720     auto uses = consumers.find(value);
721     // Some values have no consumers (e.g. outputs of the main function).
722     if (uses != consumers.end())
723       for (auto* use : uses->second)
724         if (!mlir::isa<mlir::TF::CopyToMeshOp>(use->getOwner()))
725           operations.insert(use->getOwner());
726     // If this is an OpResult, also add the op that produces it.
727     if (value.isa<mlir::OpResult>() &&
728         !mlir::isa<mlir::TF::CopyToMeshOp>(value.getDefiningOp()))
729       operations.insert(value.getDefiningOp());
730   }
731 }
732 
733 namespace {
734 
735 // Custom printing class which prints out layouts and ignores DTensorLayout
736 // ops and also non-registered attributes.
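// A printed op looks roughly like (illustrative format, not exact):
//   %4 "x,*" tensor<8x8xf32> = "tf.MatMul"(%2 "x,*" tensor<8x4xf32>,
//                                          %3 "*,*" tensor<4x8xf32>) [MatMul]
// where the quoted string after each value is its propagated layout.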
737 class LayoutPrinter : public mlir::OpAsmPrinter {
738  public:
739   explicit LayoutPrinter(
740       llvm::raw_ostream& os,
741       const llvm::DenseMap<mlir::Value, Layout>& merged_layouts)
742       : indent_level_(0),
743         os_(os),
744         current_location_(0),
745         next_argument_(0),
746         merged_layouts_(merged_layouts) {}
747 
748   llvm::raw_ostream& getStream() const override { return os_; }
749 
750   void printRegionArgument(mlir::BlockArgument arg,
751                            llvm::ArrayRef<mlir::NamedAttribute> argAttrs,
752                            bool omitType) override {
753     printOperand(arg);
754     if (!omitType) {
755       os_ << ": ";
756       printType(arg.getType());
757     }
758     printOptionalAttrDict(argAttrs, llvm::None);
759   }
760 
761   void printOperand(mlir::Value value) override { printOperand(value, os_); }
762 
763   /// Print a newline and indent the printer to the start of the current
764   /// operation.
765   void printNewline() override {
766     os_ << "\n";
767     os_.indent(indent_level_);
768   }
769 
770   // Note that we ignore the parameters printEntryBlockArgs and
771   // printBlockTerminators for simplicity.
772   void printRegion(mlir::Region& blocks, bool printEntryBlockArgs,
773                    bool printBlockTerminators,
774                    bool printEmptyBlock = false) override {
775     os_ << " {\n";
776     for (auto& b : blocks.getBlocks()) print(b);
777     os_.indent(indent_level_) << "}";
778   }
779 
780   void print(mlir::Block& block) {
781     // Each nested block level increases the indent.
782     os_.indent(indent_level_) << "%bb(";
783     for (int i = 0; i < block.getNumArguments(); ++i) {
784       if (arguments_.find(block.getArgument(i)) == arguments_.end())
785         arguments_[block.getArgument(i)] = next_argument_++;
786       if (i > 0) os_ << ", ";
787       os_ << "%arg" << arguments_[block.getArgument(i)];
788     }
789     os_ << "):\n";
790     indent_level_ += 2;
791     for (auto& op : block.getOperations()) print(op);
792     indent_level_ -= 2;
793   }
794 
795   // Prints the TF node name from `loc`.
796   void printLoc(mlir::Location loc) {
797     os_ << " [" << mlir::GetNameFromLoc(loc) << "]";
798   }
799 
800   void print(mlir::Operation& op) {
801     // Don't print tf.DTensorLayout ops.
802     if (mlir::isa<mlir::TF::DTensorLayout>(op)) return;
803 
804     // Don't print functions with empty bodies.
805     if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(op))
806       if (func_op.empty()) return;
807 
808     // Each operation is on its own line, so we start by indenting the
809     // line.
810     os_.indent(indent_level_);
811 
812     // Record a unique identifier for the op (this will be used for printing
813     // op results and operands).
814     location_[&op] = current_location_++;
815 
816     // Print the outputs.
817     for (int i = 0; i < op.getNumResults(); ++i) {
818       if (i > 0) os_ << ", ";
819       printOperand(op.getOpResult(i), os_);
820     }
821 
822     if (op.getNumResults() > 0) os_ << " = ";
823 
824     // Some ops have a special printing method, call this if it exists.
825     if (auto opInfo = op.getRegisteredInfo()) {
826       opInfo->printAssembly(&op, *this, /*defaultDialect=*/"");
827       printLoc(op.getLoc());
828       os_ << "\n";
829       return;
830     }
831 
832     // Otherwise we do a generic printing.
833     printGenericOp(&op, true);
834     printLoc(op.getLoc());
835 
836     os_ << "\n";
837   }
838 
839   // Print an operand; this could be either an OpResult or a BlockArgument.
840   // We also print the layout if it exists and the type.
841   void printOperand(mlir::Value value, llvm::raw_ostream& os) override {
842     if (auto result = value.dyn_cast<mlir::OpResult>()) {
843       // If DTensorLayout ops are already in the module, we need to skip them
844       // since we aren't printing them out.
845       if (mlir::isa<mlir::TF::DTensorLayout>(result.getDefiningOp())) {
846         printOperand(result.getDefiningOp()->getOperand(0));
847         return;
848       }
849 
850       // OpResults are of the format %op_number:%result_number. We elide the
851       // result number if there is only one result (the case for most ops).
852       os << "%" << location_[result.getDefiningOp()];
853       if (result.getDefiningOp()->getNumResults() > 1)
854         os << ":" << result.getResultNumber();
855     } else if (auto argument = value.dyn_cast<mlir::BlockArgument>()) {
856       if (arguments_.find(argument) == arguments_.end())
857         arguments_[argument] = next_argument_++;
858       os << "%arg" << arguments_[argument];
859     }
860     auto layout = merged_layouts_.find(value);
861     if (layout != merged_layouts_.end()) {
862       os << " \"";
863       printLayout(layout->second, os);
864       os << "\"";
865     }
866     os << " ";
867     printType(value.getType());
868   }
869 
870   void printLayout(const Layout& layout, llvm::raw_ostream& os) {
871     // Layouts are printed with * for an unsharded dim and the mesh dim for a
872     // sharded dim. This keeps the layout compact.
873     for (int i = 0; i < layout.rank(); ++i) {
874       if (i > 0) os << ",";
875       if (Layout::IsUnshardedDimension(layout.sharding_spec(i)))
876         os << "*";
877       else
878         os << layout.sharding_spec(i);
879     }
880   }
881 
882   // A generic op consists of a name, and any of the following:
883   // * arguments,
884   // * attributes
885   // * regions
886   // These are printed out in that order.
887   void printGenericOp(mlir::Operation* op, bool printOpName) override {
888     if (printOpName) os_ << "\"" << op->getName().getStringRef() << "\"";
889     os_ << "(";
890     for (int i = 0; i < op->getNumOperands(); ++i) {
891       if (i > 0) os_ << ", ";
892       printOperand(op->getOperand(i), os_);
893     }
894     os_ << ")";
895 
896     if (!op->getAttrs().empty()) {
897       std::vector<mlir::NamedAttribute> filtered;
898       for (auto attr : op->getAttrs())
899         if (*attr.getName().str().begin() != '_' &&
900             attr.getName().str() != "device")
901           filtered.emplace_back(attr);
902       if (!filtered.empty()) {
903         os_ << " {";
904         for (int i = 0; i < filtered.size(); ++i) {
905           if (i > 0) os_ << ", ";
906           printNamedAttribute(filtered[i]);
907         }
908         os_ << "}";
909       }
910     }
911 
912     if (op->getNumRegions() > 0) {
913       os_ << " (";
914       for (auto& region : op->getRegions()) printRegion(region, false, false);
915       os_ << ")";
916     }
917   };
918 
919   void printSymbolName(llvm::StringRef symbolRef) override {
920     os_ << symbolRef;
921   };
922 
923   void printNamedAttribute(mlir::NamedAttribute attr) {
924     os_ << attr.getName().strref() << " = ";
925     printAttribute(attr.getValue());
926   }
927 
928   void printAttribute(mlir::Attribute attr) override { attr.print(os_); }
929 
930   void printType(mlir::Type type) override { type.print(os_); }
931 
932   // The following functions are part of the printing interface but aren't
933   // needed for the compact printing form for Layout printing.
934   void printAttributeWithoutType(mlir::Attribute attr) override{};
935   void printSuccessor(mlir::Block* successor) override{};
936   void printSuccessorAndUseList(mlir::Block* successor,
937                                 mlir::ValueRange succOperands) override{};
938   void printOptionalAttrDict(
939       llvm::ArrayRef<mlir::NamedAttribute> attrs,
940       llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{};
941   void printOptionalAttrDictWithKeyword(
942       llvm::ArrayRef<mlir::NamedAttribute> attrs,
943       llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{};
944 
945   void shadowRegionArgs(mlir::Region& region,
946                         mlir::ValueRange namesToUse) override{};
947   void printAffineMapOfSSAIds(mlir::AffineMapAttr mapAttr,
948                               mlir::ValueRange operands) override{};
949 
950   void printAffineExprOfSSAIds(mlir::AffineExpr expr,
951                                mlir::ValueRange dimOperands,
952                                mlir::ValueRange symOperands) override{};
953 
954  private:
955   int indent_level_;
956   llvm::raw_ostream& os_;
957   llvm::DenseMap<mlir::Operation*, int> location_;
958   int current_location_;
959   llvm::DenseMap<mlir::BlockArgument, int> arguments_;
960   int next_argument_;
961   const llvm::DenseMap<mlir::Value, Layout>& merged_layouts_;
962 };
963 
964 // Log the current set of layouts to a file marked by the hash of the input
965 // module and the stage.
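// The dump is written under the directory returned by GetDumpDirFromEnvVar(),
// with a name of the form layout_propagation_v2_module_<hash>_stage_<N>_*.mlir.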
966 void LogLayoutsAndOps(const int stage, const uint64_t module_hash,
967                       const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
968                       mlir::ModuleOp& module) {
969   if (module->hasAttr(kDoNotLog) || ((ClientId() != 0) && !LogOnAllTasks()))
970     return;
971 
972   std::string prefix = tensorflow::GetDumpDirFromEnvVar();
973   if (prefix.empty()) return;
974 
975   auto* env = tensorflow::Env::Default();
976   auto status = env->RecursivelyCreateDir(prefix);
977   if (!status.ok()) {
978     LOG(WARNING) << "cannot create directory '" + prefix +
979                         "': " + status.error_message();
980     return;
981   }
982 
983   absl::StrAppend(&prefix, "/layout_propagation_v2_module_", module_hash,
984                   "_stage_", stage, "_");
985   if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
986     LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
987     return;
988   }
989 
990   std::unique_ptr<WritableFile> file_writer;
991   status = env->NewWritableFile(prefix, &file_writer);
992   if (!status.ok()) {
993     LOG(WARNING) << "cannot open file '" + prefix +
994                         "': " + status.error_message();
995     return;
996   }
997 
998   // Print the module to a string before writing to the file.
999   std::string txt_module;
1000   {
1001     llvm::raw_string_ostream os(txt_module);
1002     LayoutPrinter printer(os, merged_layouts);
1003     module.print(printer);
1004   }
1005 
1006   status = file_writer->Append(txt_module);
1007   if (!status.ok()) {
1008     LOG(WARNING) << "error writing to file '" + prefix +
1009                         "': " + status.error_message();
1010     return;
1011   }
1012   (void)file_writer->Close();
1013   LOG(INFO) << "Dumped MLIR module to " << prefix;
1014 }
1015 
1016 // Canonicalizer and DCE transformation passes may remove ops in the graph and
1017 // result in multiple consecutive DTensorLayout ops. Detect all such cases and
1018 // replace unnecessary DTensorLayout ops with Identity ops.
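// For example (illustrative), the pair
//   %1 = "tf.DTensorLayout"(%0) {layout = #L}
//   %2 = "tf.DTensorLayout"(%1) {layout = #L}
// is rewritten so that %2 becomes "tf.Identity"(%1), leaving a single
// DTensorLayout op for the value.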
1019 mlir::LogicalResult ReplaceAuxiliaryDTensorLayoutOpsWithIdentity(
1020     mlir::ModuleOp module) {
1021   llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops;
1022   module.walk([&](mlir::TF::DTensorLayout op) { layout_ops.emplace_back(op); });
1023 
1024   for (auto layout_op : llvm::reverse(layout_ops)) {
1025     auto input_op = layout_op.input().getDefiningOp();
1026     if (auto input_layout_op =
1027             llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(input_op)) {
1028       // Check that layout of input DTensorLayout op is equivalent to
1029       // the layout of its connected DTensorLayout op.
1030       if (layout_op.layout() != input_layout_op.layout())
1031         return layout_op.emitOpError(
1032             "Found inconsistent layout. This should never happen.");
1033 
1034       // Replace DTensorLayout op with identity op.
1035       mlir::OpBuilder builder(layout_op);
1036       auto identity = builder.create<mlir::TF::IdentityOp>(
1037           layout_op->getLoc(), layout_op.getType(), layout_op.input());
1038       layout_op.output().replaceAllUsesWith(identity.output());
1039       layout_op.erase();
1040     }
1041   }
1042 
1043   return mlir::success();
1044 }
1045 
1046 // Inserts/updates DTensorLayout ops after an IfRegion op and after the results
1047 // of its then/else branches to ensure that the return values are consistent.
1048 // After layout propagation, the layouts of the return values of a tf.IfRegion
1049 // op and the layouts of the terminators of its then/else branches may differ.
1050 // In that case, the layouts must be merged into a single layout, since the
1051 // return values of the IfRegion op and the results of its then/else branches
1052 // are semantically equivalent.
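// For example (illustrative), if the IfRegion result is annotated with (*, "y")
// while the then/else terminators are annotated with ("x", *) and (*, "y"), all
// three DTensorLayout ops are updated to the most sharded merge ("x", "y").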
1053 mlir::LogicalResult InsertDTensorLayoutForIfRegionOp(
1054     const llvm::SmallVectorImpl<mlir::TF::IfRegionOp>& if_ops,
1055     mlir::MLIRContext* context) {
1056   for (mlir::TF::IfRegionOp if_op : if_ops) {
1057     for (mlir::OpResult if_result : if_op.getResults()) {
1058       const int result_index = if_result.getResultNumber();
1059       mlir::Value then_branch_result = if_op.then_branch()
1060                                            .front()
1061                                            .getTerminator()
1062                                            ->getOpOperand(result_index)
1063                                            .get();
1064       mlir::Value else_branch_result = if_op.else_branch()
1065                                            .front()
1066                                            .getTerminator()
1067                                            ->getOpOperand(result_index)
1068                                            .get();
1069 
1070       auto if_result_layout =
1071           llvm::dyn_cast<mlir::TF::DTensorLayout>(*if_result.user_begin());
1072       auto then_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(
1073           *then_branch_result.getDefiningOp());
1074       auto else_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(
1075           *else_branch_result.getDefiningOp());
1076       llvm::SmallVector<Layout, 4> layouts{if_result_layout.layout(),
1077                                            then_result_layout.layout(),
1078                                            else_result_layout.layout()};
1079       std::set<Layout> layouts_set{layouts.begin(), layouts.end()};
1080       if (layouts_set.size() == 1) continue;
1081 
1082       absl::optional<Layout> merged_layout;
1083       if (mlir::failed(
1084               GetMostShardedLayout(layouts, if_op.getLoc(), &merged_layout)))
1085         return mlir::failure();
1086       assert(merged_layout);
1087 
1088       if_result_layout->setAttr(
1089           kQualifiedLayoutAttr,
1090           mlir::dtensor::LayoutAttr::get(context, *merged_layout));
1091       then_result_layout->setAttr(
1092           kQualifiedLayoutAttr,
1093           mlir::dtensor::LayoutAttr::get(context, *merged_layout));
1094       else_result_layout->setAttr(
1095           kQualifiedLayoutAttr,
1096           mlir::dtensor::LayoutAttr::get(context, *merged_layout));
1097     }
1098   }
1099   return mlir::success();
1100 }
1101 
1102 // Inserts necessary DTensorRelayout ops so that the layouts for while loops
1103 // are correct.
1104 //
1105 // Due to how while loop layout propagation is done, we may need to fix the
1106 // layouts so that the second and subsequent steps of the loop receive a tensor
1107 // the correct layout.
1108 // E.g.
1109 // %b = tf.WhileRegion(%a) ({
1110 //     %bb0(%arg0):  # Cond
1111 //       %c = tf.A(%arg0)
1112 //       tf.Yield(%c)
1113 //     }, {
1114 //     %bb0(%arg0):  # Body
1115 //       %d = tf.B(%arg0)
1116 //       tf.Yield(%d)
1117 //     }
1118 //   }
1119 // %e = tf.C(%b)
1120 //
1121 // Layout propagation treats the loop body as if it were an inlined function and
1122 // has no constraint that fixes the layout of %d, as the return value, to
1123 // match the layout of %arg0 (or %a).
1124 //
1125 // Towards this, we:
1126 // 1) Check the layout of %arg0 and see if it matches the layout of input 0
1127 //    (%d) of tf.Yield.
1128 // 2) If it doesn't match, we insert a DTensorRelayout op between
1129 //    %d and tf.Yield with the correct layout and insert a second
1130 //    DTensorRelayout op after the loop body.
1131 //
1132 // NOTE: it is necessary in general to insert both DTensorRelayout ops,
1133 // as opposed to just updating the layout of %d (which would in general be more
1134 // efficient) since %d may still be used by other ops in the loop body.
1135 //
1136 // NOTE: this is not needed for the condition as the output of the condition is
1137 // a scalar and therefore always replicated.
1138 mlir::LogicalResult InsertRelayoutForWhileLoops(
1139     const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops,
1140     mlir::OpBuilder& builder) {
1141   for (mlir::TF::WhileRegionOp op : while_ops) {
1142     // Get the terminator so we can check the output layouts of the loop body.
1143     mlir::Operation* yield_op = op.body().front().getTerminator();
1144     if (!mlir::isa<mlir::TF::YieldOp>(yield_op))
1145       return op->emitOpError() << "body terminator is not a Yield op.";
1146 
1147     for (int i = 0; i < op.body().getNumArguments(); ++i) {
1148       // Inputs should have only one use, a DTensorLayout op.
1149       mlir::Value argument = op.body().getArgument(i);
1150       if (!argument.hasOneUse())
1151         return op.emitOpError()
1152                << "body argument " << i << " doesn't have a single use.";
1153       mlir::Operation* input_layout_op = argument.getUses().begin().getUser();
1154       if (!mlir::isa<mlir::TF::DTensorLayout>(input_layout_op))
1155         return op.emitOpError() << "body argument " << i
1156                                 << " is not consumed by a DTensorLayout op.";
1157       const Layout input_layout =
1158           mlir::cast<mlir::TF::DTensorLayout>(input_layout_op).layout();
1159 
1160       // The inputs to Yield should also be produced by DTensorLayout ops.
1161       if (!yield_op->getOperand(i).isa<mlir::OpResult>() ||
1162           !mlir::isa<mlir::TF::DTensorLayout>(
1163               yield_op->getOperand(i).getDefiningOp()))
1164         return yield_op->emitOpError()
1165                << "argument " << i << " is not a DTensorLayout op.";
1166       mlir::Operation* output_layout_op =
1167           yield_op->getOperand(i).getDefiningOp();
1168       const Layout output_layout =
1169           mlir::cast<mlir::TF::DTensorLayout>(output_layout_op).layout();
1170 
1171       // If the layouts are equal we have nothing to do. Note that this catches
1172       // the case where the input and output are a resource, since the layout
1173       // of a resource is fixed.
1174       if (input_layout == output_layout) continue;
1175 
1176       // Insert the first Relayout op (in the loop body).
1177       builder.setInsertionPointAfter(output_layout_op);
1178       if (!yield_op->getOperand(i).getType().isa<mlir::TensorType>())
1179         return yield_op->emitOpError()
1180                << "operand " << i << " does not have TensorType";
1181       mlir::TF::ShapeAttr global_shape = mlir::TF::ShapeAttr::get(
1182           builder.getContext(),
1183           yield_op->getOperand(i).getType().cast<mlir::TensorType>());
1184       mlir::TF::RelayoutOp first_relayout =
1185           builder.create<mlir::TF::RelayoutOp>(
1186               op.getLoc(), yield_op->getOperand(i).getType(),
1187               yield_op->getOperand(i), input_layout.ToString());
1188       mlir::TF::DTensorLayout first_layout_op =
1189           builder.create<mlir::TF::DTensorLayout>(
1190               op.getLoc(), first_relayout.output(),
1191               mlir::dtensor::LayoutAttr::get(builder.getContext(),
1192                                              input_layout),
1193               global_shape);
1194       yield_op->setOperand(i, first_layout_op.output());
1195 
1196       // Insert the second relayout op after the loop itself.
1197       builder.setInsertionPointAfter(op);
1198       mlir::TF::DTensorLayout second_layout_op =
1199           builder.create<mlir::TF::DTensorLayout>(
1200               op.getLoc(), op->getResult(i),
1201               mlir::dtensor::LayoutAttr::get(builder.getContext(),
1202                                              input_layout),
1203               global_shape);
1204       mlir::TF::RelayoutOp second_relayout =
1205           builder.create<mlir::TF::RelayoutOp>(
1206               op.getLoc(), second_layout_op.output().getType(),
1207               second_layout_op.output(), output_layout.ToString());
1208       op->getResult(i).replaceAllUsesExcept(
1209           second_relayout.output(), llvm::SmallPtrSet<mlir::Operation*, 1>{
1210                                         second_layout_op.getOperation()});
1211     }
1212   }
1213   return mlir::success();
1214 }
1215 
1216 // For all constants with multiple uses, clone the constants so that each
1217 // constant operation has at most one use.
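//
// For illustration only (hypothetical IR, abbreviated; not produced verbatim
// by this function):
//   %c = "tf.Const"() ...
//   %x = "tf.A"(%c)
//   %y = "tf.B"(%c)
// becomes
//   %c  = "tf.Const"() ...
//   %c0 = "tf.Const"() ...
//   %x = "tf.A"(%c)
//   %y = "tf.B"(%c0)
// so that each consumer can request its own layout for its copy of the
// constant.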
1218 void DuplicateConstants(mlir::ModuleOp module) {
1219   llvm::SmallVector<mlir::TF::ConstOp, 4> const_ops;
1220   module.walk(
1221       [&](mlir::TF::ConstOp const_op) { const_ops.emplace_back(const_op); });
1222 
1223   for (mlir::TF::ConstOp const_op : const_ops) {
1224     mlir::OpBuilder builder(const_op);
1225     auto uses = const_op->getUses();
1226     if (uses.empty()) continue;
1227 
1228     llvm::SmallDenseMap<mlir::Operation*, mlir::OpOperand*> const_use_map;
1229     mlir::OpOperand& first_use = *uses.begin();
1230     for (mlir::OpOperand& use : uses) {
1231       if (&use == &first_use) continue;
1232 
1233       mlir::Operation* new_const = builder.clone(*const_op);
1234       const_use_map.try_emplace(new_const, &use);
1235     }
1236 
1237     for (const auto& it : const_use_map) it.second->set(it.first->getResult(0));
1238   }
1239 }
1240 
1241 // Finds the root value(s) of "current_value" within the cycle, and puts them
1242 // into "roots".
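//
// For example (a hypothetical chain of updated values): if %a, %b and %c are
// all in `is_updated`, %b is computed from %a, %c from %b, and the op defining
// %a has no operands in `is_updated`, then calling FindRoot with
// current_value = %c puts %a into `roots`.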
1243 void FindRoot(
1244     const llvm::DenseSet<mlir::Value>& is_updated,
1245     const mlir::Value& current_value,
1246     llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
1247     llvm::DenseSet<mlir::Value>* roots) {
1248   // Standard BFS to find root values of current_value.
1249   std::deque<mlir::Value> to_process;
1250   to_process.push_back(current_value);
1251 
1252   llvm::DenseSet<mlir::Value> visited;
1253   visited.insert(current_value);
1254 
1255   while (!to_process.empty()) {
1256     int level_size = to_process.size();
1257     for (int UNUSED = 0; UNUSED < level_size; ++UNUSED) {
1258       mlir::Value cur_val = to_process.front();
1259       to_process.pop_front();
1260 
1261       // Terminating condition: if there is no defining op, then this is a root.
1262       mlir::Operation* defining_op = cur_val.getDefiningOp();
1263       if (defining_op == nullptr) {
1264         roots->insert(current_value);
1265         continue;
1266       }
1267 
1268       // Expand out from 'cur_val' one step closer to the roots. If there is
1269       // no value one step closer to a root, then 'cur_val' is itself a root.
1270       bool is_root = true;
1271       for (int i = 0; i < defining_op->getNumOperands(); ++i) {
1272         mlir::Value operand = defining_op->getOperand(i);
1273         if (operand != cur_val && is_updated.contains(operand)) {
1274           is_root = false;
1275 
1276           if (!visited.contains(operand)) {
1277             visited.insert(operand);
1278             to_process.push_back(operand);
1279           }
1280         }
1281       }
1282 
1283       if (is_root) roots->insert(cur_val);
1284     }
1285   }
1286 }
1287 
1288 // Finds the root value(s) of the values that have layouts cycling back and
1289 // forth in an infinite loop during layout propagation and prints the closest TF
1290 // op that consumes those root value(s). This allows users and developers to
1291 // debug the root cause of layouts that should be changed to prevent infinite
1292 // layout propagation.
1293 void FindRootsAndEmitError(
1294     mlir::ModuleOp& module,
1295     llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers,
1296     const llvm::DenseSet<mlir::Value>& is_updated) {
1297   llvm::DenseSet<mlir::Value> roots;
1298   for (auto& value : is_updated) {
1299     FindRoot(is_updated, value, producers, &roots);
1300   }
1301   module.emitOpError()
1302       << "Maximum number of layout propagation steps reached. Unable to "
1303          "converge to a fixed layout. Please rerun with a higher limit in the "
1304          "DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS environment variable.";
1305   for (auto& root : roots) {
1306     for (mlir::OpOperand& operand : root.getUses()) {
1307       llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;
1308       llvm::SmallVector<mlir::Value, 4> skipped_values;
1309 
1310       // For each root value that may need a different layout, find the
1311       // closest TF op that consumes it and print it.
1312       llvm::SmallVector<mlir::OpOperand*, 4> consuming_operands =
1313           TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values);
1314 
1315       for (mlir::OpOperand* new_operand : consuming_operands) {
1316         mlir::Operation* operation = new_operand->getOwner();
1317         mlir::Location loc = operation->getLoc();
1318         operation->emitOpError() << '\n'
1319                                  << "The following op consumes tensors that "
1320                                     "may need a different layout. "
1321                                     "["
1322                                  << mlir::GetNameFromLoc(loc) << "]" << '\n';
1323       }
1324     }
1325   }
1326 }
1327 }  // namespace
1328 
1329 // Runs an iteration of layout propagation, where we merge producer and consumer
1330 // requests and then recompute recommended layouts on all operations that
1331 // are connected to an updated layout.
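//
// A rough outline of one iteration (a descriptive sketch of the calls below,
// not additional behavior):
//   1. MergeAndGetUpdatedLayouts folds producer_request / consumer_requests
//      into merged_layouts for the values in is_updated.
//   2. GetOperationsNeedingUpdate collects the ops touching an updated value.
//   3. UpdateLayoutsForOp recomputes layouts for those ops, which may mark
//      further values in is_updated for the next iteration.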
1332 Status RunOneIteration(
1333     llvm::DenseSet<mlir::Value>& is_locked,
1334     llvm::DenseSet<mlir::Value>& is_updated,
1335     llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
1336     llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
1337         consumer_requests,
1338     llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
1339     llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
1340     llvm::DenseMap<mlir::Value, Layout>& merged_layouts, mlir::ModuleOp& module,
1341     const uint64_t module_hash, int* stage) {
1342   if (is_updated.empty()) return Status::OK();
1343   // Merge any possibly updated layouts.
1344   if (mlir::failed(
1345           MergeAndGetUpdatedLayouts(is_locked, is_updated, producer_request,
1346                                     consumer_requests, merged_layouts)))
1347     return errors::Internal(
1348         "MergeAndGetUpdatedLayouts failed to merge layouts.");
1349 
1350   // Compile a list of operations with updated inputs or outputs.
1351   llvm::DenseSet<mlir::Operation*> operations_needing_update;
1352   GetOperationsNeedingUpdate(is_updated, consumers, operations_needing_update);
1353   is_updated.clear();
1354 
1355   if (VLOG_IS_ON(2)) {
1356     LogLayoutsAndOps(*stage, module_hash, merged_layouts, module);
1357   }
1358 
1359   for (auto* op : operations_needing_update) {
1360     if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts,
1361                                         producer_request, consumer_requests,
1362                                         is_updated)))
1363       return errors::Internal("UpdateLayoutsForOp failed to update layouts.");
1364   }
1365   ++(*stage);
1366   return Status::OK();
1367 }
1368 
1369 // Compares every value's layout in `merged_a` with the one in `merged_b`,
1370 // and stores the values that differ in `changed`.
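//
// For example (hypothetical layouts): if merged_a maps %v -> "x,unsharded"
// and %w -> "unsharded,unsharded", while merged_b maps %v -> "x,unsharded"
// and %w -> "y,unsharded", then `changed` ends up containing only %w.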
1371 Status CompareMergedLayouts(const llvm::DenseMap<mlir::Value, Layout>& merged_a,
1372                             const llvm::DenseMap<mlir::Value, Layout>& merged_b,
1373                             llvm::DenseSet<mlir::Value>& changed) {
1374   if (merged_a.size() != merged_b.size())
1375     return errors::Internal(
1376         "Both merged_layouts did not have the same number of set layouts.");
1377   for (const auto& value_and_layout : merged_a) {
1378     const mlir::Value value = value_and_layout.getFirst();
1379     const Layout& layout = value_and_layout.getSecond();
1380     auto value_and_layout_in_b = merged_b.find(value);
1381     if (value_and_layout_in_b == merged_b.end())
1382       return errors::Internal(
1383           "Comparing merged_layouts that contain different mlir::Value's.");
1384     if (value_and_layout_in_b->second != layout) {
1385       changed.insert(value);
1386     }
1387   }
1388   return Status::OK();
1389 }
1390 
1391 // MLIR pass that propagates layouts for all ops in the module.
1392 struct DLayoutPropagationPassV2
1393     : public DTensorLayoutPropagationV2Base<DLayoutPropagationPassV2> {
1394   void getDependentDialects(mlir::DialectRegistry& registry) const override {
1395     registry.insert<mlir::dtensor::DTensorDialect>();
1396   }
1397 
1398   void runOnOperation() override {
1399     mlir::MLIRContext& context = getContext();
1400     mlir::OpBuilder builder(&context);
1401 
1402     auto module = getOperation();
1403 
1404     if (mlir::failed(ReplaceAuxiliaryDTensorLayoutOpsWithIdentity(module)))
1405       return signalPassFailure();
1406 
1407     // To ensure that constant operations with multiple uses and different
1408     // consumer layout requests do not lead to replicated constant tensors,
1409     // we duplicate all constants so that each has at most one use.
1410     // After SPMD expansion, these duplicated constants will be merged back
1411     // during the SCCP pass.
1412     DuplicateConstants(module);
1413 
1414     mlir::func::FuncOp main_func =
1415         module.lookupSymbol<mlir::func::FuncOp>("main");
1416     if (!main_func) return;
1417 
1418     mlir::Dialect* tf_dialect =
1419         context.getLoadedDialect<mlir::TF::TensorFlowDialect>();
1420 
1421     // This maps from OpResults to the list of OpOperands that consume them.
1422     // Note that this will pass over/through
1423     // (Stateful)PartitionedCall and other control flow, directly connecting
1424     // producing ops to their consumers in the function, i.e. it presents a
1425     // flattened/inlined view of the flow of data.
1426     llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
1427     // Maintain a reverse mapping.
1428     llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers;
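    // For example (a hypothetical two-op function):
    //   %x = "tf.A"()
    //   %y = "tf.B"(%x)
    // Here consumers[%x] would contain the OpOperand for tf.B's first input,
    // and producers[that OpOperand] would contain %x.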
1429     // For each mlir::Value this is what the producer would like to have the
1430     // layout be.
1431     llvm::DenseMap<mlir::Value, absl::optional<Layout>> producer_request;
1432     // For each mlir::Value this is what the consumers would like to have the
1433     // layout be. Note the map is in 'parallel' to the consumers map above.
1434     llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>
1435         consumer_requests;
1436     // Tracks if the layout was updated since last cycle.
1437     llvm::DenseSet<mlir::Value> is_updated;
1438     // Tracks if the layout is locked. In this case we don't pass consumer
1439     // layouts to MergeLayouts. Used for input layouts and user defined layouts.
1440     llvm::DenseSet<mlir::Value> is_locked;
1441 
1442     // Create consumers and producers maps.
1443     if (mlir::failed(
1444             PopulateConsumersFromModule(&module, tf_dialect, consumers)))
1445       return signalPassFailure();
1446 
1447     for (auto& consumer : consumers) {
1448       for (auto* operand : consumer.second) {
1449         if (producers.find(operand) == producers.end()) {
1450           producers[operand] = std::vector<mlir::Value>{consumer.first};
1451         } else {
1452           producers[operand].emplace_back(consumer.first);
1453         }
1454       }
1455     }
1456 
1457     // Set up the initial starting conditions for the layout algorithm.
1458     if (mlir::failed(InsertInitialLayouts(
1459             module, main_func, consumers, producers, consumer_requests,
1460             producer_request, is_updated, is_locked)))
1461       return signalPassFailure();
1462 
1463     const auto module_hash = OpHash(module);
1464     int stage = 0;
1465 
1466     llvm::DenseMap<mlir::Value, Layout> merged_layouts;
1467     Status status;
1468 
1469     while (!is_updated.empty() && stage < kLayoutPropagationMaxStages) {
1470       ++stage;
1471       int steps = 0;
1472       // Step 1. Run the layout propagation v2 until convergence or max steps.
1473       while (!is_updated.empty() && steps < LayoutPropagationMaxSteps()) {
1474         Status status = RunOneIteration(
1475             is_locked, is_updated, producer_request, consumer_requests,
1476             producers, consumers, merged_layouts, module, module_hash, &steps);
1477         if (!status.ok()) {
1478           module.emitOpError() << "Failure running iteration.";
1479           return signalPassFailure();
1480         }
1481       }
1482       if (VLOG_IS_ON(2)) {
1483         LOG(INFO) << "Failed to converge in stage " << stage;
1484       }
1485       // Step 2. If we are here, then we failed to converge, and likely
1486       // there is an oscillation of layouts. Detect all the edges that are
1487       // changing layouts.
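      // For example (hypothetical): if some value keeps flipping between two
      // layouts on successive iterations, it ends up in `changed` below and is
      // reset in Step 4, while every value whose layout stayed stable is
      // frozen in Step 3.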
1488       llvm::DenseMap<mlir::Value, Layout> merged_layouts_at_max_steps =
1489           merged_layouts;
1490       llvm::DenseSet<mlir::Value> changed;
1491       int previous_change_size = -1;
1492 
1493       while (static_cast<int>(changed.size()) > previous_change_size) {
1494         if (!RunOneIteration(is_locked, is_updated, producer_request,
1495                              consumer_requests, producers, consumers,
1496                              merged_layouts, module, module_hash, &steps)
1497                  .ok()) {
1498           module.emitOpError() << "Failure running iteration.";
1499           return signalPassFailure();
1500         }
1501         if (!CompareMergedLayouts(merged_layouts_at_max_steps, merged_layouts,
1502                                   changed)
1503                  .ok()) {
1504           module.emitOpError() << "Failure comparing merged layouts.";
1505           return signalPassFailure();
1506         }
1507         previous_change_size = changed.size();
1508       }
1509 
1510       // Step 3. Layouts that haven't changed are not part of the cyclic
1511       // problem, so freeze them.
1512       for (const auto& value_and_layout : merged_layouts) {
1513         const mlir::Value value = value_and_layout.getFirst();
1514         if (changed.find(value) == changed.end()) {
1515           is_locked.insert(value);
1516         }
1517       }
1518       // Step 4. Any information corresponding to the changed layouts
1519       // should be invalidated; we do this by clearing all information
1520       // regarding them.
1521       for (const mlir::Value changed_value : changed) {
1522         producer_request.erase(changed_value);
1523         consumer_requests.erase(changed_value);
1524         merged_layouts.erase(changed_value);
1525       }
1526 
1527       // Step 5. Run ComputeLayout again on all the ops linked to the changed
1528       // layouts. The next iteration will take this information and merge again.
1529       llvm::DenseSet<mlir::Operation*> operations_needing_update;
1530       is_updated = changed;
1531       GetOperationsNeedingUpdate(is_updated, consumers,
1532                                  operations_needing_update);
1533       is_updated.clear();
1534 
1535       for (auto* op : operations_needing_update) {
1536         if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts,
1537                                             producer_request, consumer_requests,
1538                                             is_updated))) {
1539           module.emitOpError() << "Failure in UpdateLayoutsForOp.";
1540           return signalPassFailure();
1541         }
1542       }
1543     }
1544 
1545     if (stage >= kLayoutPropagationMaxStages) {
1546       FindRootsAndEmitError(module, producers, is_updated);
1547       return signalPassFailure();
1548     }
1549 
1550     if (mlir::failed(
1551             CopyLayoutsForSkippedOps(module, tf_dialect, merged_layouts)))
1552       return signalPassFailure();
1553 
1554     if (VLOG_IS_ON(2)) {
1555       LogLayoutsAndOps(stage, module_hash, merged_layouts, module);
1556     }
1557 
1558     if (!AllOpResultsHaveLayouts(&module, tf_dialect, merged_layouts))
1559       return signalPassFailure();
1560 
1561     if (mlir::failed(InsertDTensorLayoutOps(builder, merged_layouts)))
1562       return signalPassFailure();
1563 
1564     // Handle layout of control flow operations.
1565     llvm::SmallVector<mlir::TF::IfRegionOp, 4> if_ops;
1566     llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops;
1567     module.walk([&](mlir::Operation* op) {
1568       if (auto if_op = llvm::dyn_cast<mlir::TF::IfRegionOp>(op))
1569         if_ops.emplace_back(if_op);
1570       else if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op))
1571         while_ops.emplace_back(while_op);
1572     });
1573 
1574     if (mlir::failed(InsertRelayoutForWhileLoops(while_ops, builder)))
1575       return signalPassFailure();
1576 
1577     if (mlir::failed(
1578             InsertDTensorLayoutForIfRegionOp(if_ops, builder.getContext())))
1579       return signalPassFailure();
1580   };
1581 };
1582 
1583 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1584 CreateDTensorLayoutPropagationPassV2() {
1585   return std::make_unique<DLayoutPropagationPassV2>();
1586 }
1587 
1588 }  // namespace dtensor
1589 }  // namespace tensorflow
1590