1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <iterator>
18 #include <queue>
19 #include <string>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/types/optional.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
34 #include "mlir/IR/MLIRContext.h" // from @llvm-project
35 #include "mlir/IR/OpImplementation.h" // from @llvm-project
36 #include "mlir/IR/Operation.h" // from @llvm-project
37 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
38 #include "mlir/IR/Types.h" // from @llvm-project
39 #include "mlir/IR/Value.h" // from @llvm-project
40 #include "mlir/IR/Visitors.h" // from @llvm-project
41 #include "mlir/Support/LogicalResult.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
46 #include "tensorflow/compiler/mlir/utils/name_utils.h"
47 #include "tensorflow/dtensor/cc/constants.h"
48 #include "tensorflow/dtensor/cc/dtensor_utils.h"
49 #include "tensorflow/dtensor/cc/tensor_layout.h"
50 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
51 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
52 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
53 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
54 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
55 #include "tensorflow/dtensor/mlir/layout_parsing.h"
56 #include "tensorflow/dtensor/mlir/op_utils.h"
57 #include "tensorflow/dtensor/mlir/spmd_expander.h"
58 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
59 #include "tensorflow/dtensor/mlir/value_utils.h"
60
61 namespace tensorflow {
62 namespace dtensor {
63
64 // This value dictates how many times during layout propagation we allow
65 // fixing of oscillatory behaviors.
66 constexpr int kLayoutPropagationMaxStages = 3;
67
68 bool AllOpResultsHaveLayouts(
69 mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
70 const llvm::DenseMap<mlir::Value, Layout>& layouts) {
71 const auto& result = module->walk([&](mlir::Operation* op) {
72 if (op->getDialect() != tf_dialect ||
73 mlir::isa<mlir::TF::DTensorLayout>(op))
74 return mlir::WalkResult::advance();
75 for (const auto& result : op->getOpResults()) {
76 if (layouts.find(result) == layouts.end()) {
77 op->emitOpError() << "missing layout for result "
78 << result.getResultNumber();
79 return mlir::WalkResult::interrupt();
80 }
81 }
82 return mlir::WalkResult::advance();
83 });
84 return !result.wasInterrupted();
85 }
86
87 void UpdateLayoutForSkippedOps(
88 mlir::OpOperand& operand,
89 const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
90 const Layout& layout_to_copy,
91 llvm::DenseMap<mlir::Value, Layout>& layouts) {
92 llvm::SmallVector<mlir::Value, 4> skipped_values;
93 TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values);
94 for (const mlir::Value& skipped_value : skipped_values)
95 if ((!skipped_value.isa<mlir::OpResult>() ||
96 !mlir::isa<mlir::TF::DTensorLayout, mlir::tf_device::ClusterOp>(
97 skipped_value.getDefiningOp())) &&
98 layouts.find(skipped_value) == layouts.end())
99 // TraceUseToNextTFOp's contract is that it only skips over ops that
100 // act like the identity (such as function calls, returns, yields,
101 // controlflow, DTensorLayouts, etc). This means that the layout of the
102 // operand we came from is the layout we want for this value.
103 layouts[skipped_value] = layout_to_copy;
104 }
105
106 // Some ops, which are skipped by TraceUseToNextTFOp, will not have layouts
107 // for their mlir::OpResults.
108 // E.g. during the creation of the consumers map, we skip the input and output
109 // of the WhileRegion op. In particular if we have:
110 //
111 // %b = tf.WhileRegion(%a) ({
112 // %bb0(%arg0): # Cond
113 // %c = tf.A(%arg0)
114 // tf.Yield(%c)
115 // }, {
116 // %bb0(%arg0): # Body
117 // %d = tf.B(%arg0)
118 // tf.Yield(%d)
119 // }
120 // }
121 // %e = tf.C(%b)
122 //
123 // Then the consumers map would directly connect the mlir::Value %a to input 0
124 // of tf.A and tf.B, bypassing the WhileRegion and the mlir::Value of %arg0.
125 // Similarly it would connect the mlir::Value of %d directly to input 0 of
126 // tf.C bypassing the mlir::Value of %b.
127 // This means that at the end of layout propagation the skipped values would not
128 // have an assigned layout. But this layout can be derived by taking the known
129 // layout of %a and propagating to each mlir::Value that was skipped while
130 // connecting %a to the input 0 of tf.A and tf.B. Similarly we derive the layout
131 // for %b from %d.
132 //
133 // To get layouts we
134 // 1) Iterate over all ops that have layouts for their OpResults and call
135 // TraceUseToNextTFOp to get the skipped mlir::Values.
136 // 2) If any skipped mlir::Value doesn't have a layout set, then we set the
137 // layout.
138 mlir::LogicalResult CopyLayoutsForSkippedOps(
139 mlir::ModuleOp module, mlir::Dialect* tf_dialect,
140 llvm::DenseMap<mlir::Value, Layout>& layouts) {
141 llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;
142
143 if (mlir::failed(GetFuncToCaller(module, func_to_caller)))
144 return mlir::failure();
145
146 // Update layouts derived from ops.
147 module->walk([&](mlir::Operation* op) {
148 for (mlir::OpOperand& operand : op->getOpOperands()) {
149 if (layouts.find(operand.get()) == layouts.end()) continue;
150 const Layout layout = layouts[operand.get()];
151 UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts);
152 }
153 });
154
155 // Update layouts derived from inputs
156 mlir::func::FuncOp main_func =
157 module.lookupSymbol<mlir::func::FuncOp>("main");
158 if (!main_func) return mlir::success();
159
160 for (auto& value : main_func.getArguments()) {
161 if (layouts.find(value) == layouts.end()) continue;
162 const Layout layout = layouts[value];
163
164 for (mlir::OpOperand& operand : value.getUses())
165 UpdateLayoutForSkippedOps(operand, func_to_caller, layout, layouts);
166 }
167
168 return mlir::success();
169 }
170
171 namespace {
172 void FilterkAnySpecs(std::vector<std::string>& proposed_specs) {
173 for (auto& spec : proposed_specs) {
174 if (spec == Layout::kAny) spec = Layout::kUnshardedDim;
175 }
176 }
177 } // namespace
178
179 // Merges the producer and consumer layouts into a single layout.
180 // Assumes that all layouts are of the same rank.
181 // Consumers are first merged together so that we have the layout which is
182 // sharded in a tensor dim if and only if all consumers are sharded in the same
183 // sharding_spec.
184 // If producer layout is present, we merge the consumer layouts into the layout
185 // of the producer: if the consumer wants a sharded layout in a tensor dimension
186 // where the producer is unsharded *and* the mesh dimension it wants to be
187 // sharded over is not already sharded over by the producer, then we add that
188 // sharding to the producer layout.
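//
// As an illustrative sketch (hypothetical layouts on a 2D mesh with dims "x"
// and "y", not taken from any particular model):
//   producer:           (unsharded, "y")
//   consumer requests:  ("x", "y") and ("x", Layout::kAny)
//   consumer merge:     ("x", "y")
//   producer merge:     ("x", "y"), since "x" is not already used by the
//                       producer and can be adopted for tensor dim 0.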
189 StatusOr<Layout> MergeLayouts(
190 const absl::optional<Layout>& producer,
191 const mlir::DenseMap<mlir::OpOperand*, Layout>& consumers) {
192 if (consumers.empty()) return producer.value();
193
194 // Initialize the specs to those of the first consumer layout and merge
195 // consumers.
196 std::vector<std::string> proposed_specs =
197 consumers.begin()->second.sharding_spec_strs();
198 int layout_rank = proposed_specs.size();
199
200 // Verify consumer layout ranks match.
201 for (const auto& consumer : consumers) {
202 const Layout& consumer_layout = consumer.second;
203 if (consumer_layout.rank() != layout_rank)
204 return errors::InvalidArgument(
205 "found two consumer layout of different ranks: ",
206 consumer_layout.rank(), " and ", layout_rank);
207 }
208
209 // Merge consumer layouts.
210 for (const auto& consumer : consumers) {
211 const Layout& consumer_layout = consumer.second;
212
213 // Check every tensor dimension.
214 for (int j = 0; j < consumer_layout.rank(); ++j) {
215 const std::string& consumer_spec_j = consumer_layout.sharding_spec(j);
216 if (consumer_spec_j == Layout::kAny) continue;
217
218 // If the proposed spec is set as any, give priority to the consumer spec.
219 if (proposed_specs[j] == Layout::kAny) {
220 proposed_specs[j] = consumer_spec_j;
221 continue;
222 }
223
224 // If any consumer layout disagrees with the current merge, set the
225 // spec to not sharded.
226 if (proposed_specs[j] != consumer_spec_j)
227 proposed_specs[j] = Layout::kUnshardedDim;
228 }
229 }
230
231 // Filter over-sharded specs.
232 absl::flat_hash_map<std::string, int> counter;
233 for (const std::string& spec : proposed_specs) counter[spec] += 1;
234 for (std::string& spec : proposed_specs)
235 if (counter[spec] > 1) spec = Layout::kUnshardedDim;
236
237 // Return layout if there is no producer, else move into producer algorithm.
238 const Mesh mesh = consumers.begin()->second.mesh();
239 if (!producer) {
240 FilterkAnySpecs(proposed_specs);
241 return Layout::GetLayout(proposed_specs, mesh);
242 }
243
244 if (producer->rank() != layout_rank) {
245 return errors::InvalidArgument(
246 "producer and consumer layout have different ranks: ", producer->rank(),
247 " and ", layout_rank);
248 }
249
250 // For the producer merge, first we define mesh dims used by the producer to
251 // avoid creating a layout that shards twice over the same mesh dim.
252 llvm::DenseSet<llvm::StringRef> producer_dims;
253 for (int j = 0; j < producer->rank(); ++j) {
254 llvm::StringRef spec = producer->sharding_spec(j);
255 if (Layout::IsShardedDimension(spec.str())) producer_dims.insert(spec);
256 }
257 // Merge producer layout with existing layout.
258 for (int j = 0; j < producer->rank(); ++j) {
259 const std::string& producer_spec = producer->sharding_spec(j);
260
261 if (producer_spec == proposed_specs[j] || producer_spec == Layout::kAny)
262 continue;
263
264 if (proposed_specs[j] == Layout::kAny) {
265 proposed_specs[j] = producer_spec;
266 continue;
267 }
268 // If producer is unsharded and proposed_spec is sharded, we need to make
269 // sure the mesh dim is not used elsewhere. If it is, set to unsharded.
270 if (Layout::IsUnshardedDimension(producer_spec)) {
271 bool isMeshDimUsed = producer_dims.contains(proposed_specs[j]);
272 if (isMeshDimUsed) {
273 proposed_specs[j] = Layout::kUnshardedDim;
274 }
275 } else {
276 // If producer is sharded we can set layout to shard over same
277 // mesh dim.
278 //
279 // If mesh dim is already used in the layout elsewhere it will
280 // get unset by the case above.
281 proposed_specs[j] = producer_spec;
282 }
283 }
284 FilterkAnySpecs(proposed_specs);
285 return Layout::GetLayout(proposed_specs, mesh);
286 }
287
288 mlir::LogicalResult InsertLayoutsForDTensorLayout(
289 mlir::ModuleOp& module,
290 llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
291 llvm::DenseSet<mlir::Value>& is_updated,
292 llvm::DenseSet<mlir::Value>& is_locked) {
293 return mlir::failure(
294 module
295 .walk([&](mlir::TF::DTensorLayout op) -> mlir::WalkResult {
296 // Check there are no "Layout::kAny" or "kMatch" specs in the
297 // layouts.
298 for (const std::string& spec : op.layout().sharding_spec_strs())
299 if (spec == Layout::kAny || spec == Layout::kMatch)
300 return op->emitOpError()
301 << "found " << spec
302 << " as a sharding spec which is not allowed";
303 // Insert layout.
304 producer_request[op.input()].emplace(op.layout());
305 is_updated.insert(op.input());
306 is_locked.insert(op.input());
307 return mlir::WalkResult::advance();
308 })
309 .wasInterrupted());
310 }
311
312 // Runs the ComputeLayout API on all ops inside the graph **without** any
313 // consumer-requested layouts or operand layouts populated.
314 mlir::LogicalResult InsertInitialLayoutsFromComputeLayout(
315 mlir::ModuleOp module,
316 const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
317 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
318 llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
319 llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
320 consumer_requests,
321 llvm::DenseSet<mlir::Value>& is_updated) {
322 auto walk_result = module.walk([&](mlir::Operation* op) {
323 // We ignore ops that don't have either an OpResult in consumers or an
324 // OpOperand in producers. Note that if one operand is missing from
325 // producers then all operands should be missing, as should all op results
326 // from consumers, and vice versa.
327
328 if (op->getNumOperands() > 0) {
329 if (producers.find(&op->getOpOperand(0)) == producers.end())
330 return mlir::WalkResult::advance();
331 } else if (op->getNumResults() > 0) {
332 if (consumers.find(op->getOpResult(0)) == consumers.end())
333 return mlir::WalkResult::advance();
334 } else {
335 // Note that this case should never happen (i.e. a TF op should have
336 // either inputs or outputs, but that isn't technically guaranteed).
337 return mlir::WalkResult::advance();
338 }
339
340 auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
341 if (expander == nullptr) {
342 op->emitOpError() << "does not implement layout propagation";
343 return mlir::WalkResult::interrupt();
344 }
345
346 // Invoke ComputeLayout on `op` with empty input/consumer layouts.
347 StatusOr<llvm::DenseMap<int, Layout>> forward_result =
348 expander->ComputeLayoutForward(
349 op, /*input_layouts=*/llvm::DenseMap<int, Layout>(),
350 /*output_layouts=*/llvm::DenseMap<int, Layout>());
351 if (!forward_result.ok()) {
352 op->emitOpError() << "ComputeLayoutForward error: "
353 << forward_result.status().error_message();
354 return mlir::WalkResult::interrupt();
355 }
356 StatusOr<llvm::DenseMap<int, Layout>> backward_result =
357 expander->ComputeLayoutBackward(
358 op, /*input_layouts=*/llvm::DenseMap<int, Layout>(),
359 /*output_layouts=*/llvm::DenseMap<int, Layout>());
360 if (!backward_result.ok()) {
361 op->emitOpError() << "ComputeLayoutBackward error: "
362 << backward_result.status().error_message();
363 return mlir::WalkResult::interrupt();
364 }
365
366 // If any operand layouts were returned, add the layout to consumer requests
367 // and set the value as updated.
368 for (auto const& op_idx_and_op_layout : *backward_result) {
369 auto const& op_idx = op_idx_and_op_layout.first;
370 auto const& op_layout = op_idx_and_op_layout.second;
371 auto& operand = op->getOpOperand(op_idx);
372 const auto& producer_values = producers.lookup(&operand);
373 for (mlir::Value producer_value : producer_values) {
374 if (!consumer_requests[producer_value].count(&operand))
375 consumer_requests[producer_value][&operand] = op_layout;
376
377 is_updated.insert(producer_value);
378 }
379 }
380
381 // If any output layouts were returned, add the layout to producer requests
382 // and set the value as updated.
383 for (auto const& out_idx_and_out_layout : *forward_result) {
384 auto const& out_idx = out_idx_and_out_layout.first;
385 auto const& out_layout = out_idx_and_out_layout.second;
386 mlir::Value output_value = op->getResult(out_idx);
387 producer_request.try_emplace(output_value, out_layout);
388 is_updated.insert(output_value);
389 }
390
391 return mlir::WalkResult::advance();
392 });
393 return mlir::failure(walk_result.wasInterrupted());
394 }
395
396 // Propagates mesh and inserts initial layouts for
397 // * Any DTensorLayout ops (this handles function inputs and other ops with user
398 // layouts).
399 // * CopyToMesh
400 // * ConstOp
401 mlir::LogicalResult InsertInitialLayouts(
402 mlir::ModuleOp& module, mlir::func::FuncOp& main_func,
403 const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
404 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
405 llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
406 consumer_request,
407 llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
408 llvm::DenseSet<mlir::Value>& is_updated,
409 llvm::DenseSet<mlir::Value>& is_locked) {
410 std::queue<mlir::Operation*> operations;
411
412 if (mlir::failed(InsertLayoutsForDTensorLayout(module, producer_request,
413 is_updated, is_locked)))
414 return mlir::failure();
415 return InsertInitialLayoutsFromComputeLayout(module, consumers, producers,
416 producer_request,
417 consumer_request, is_updated);
418 }
419
420 // Given a list of mlir::Values with updated producer or consumer layouts
421 // update the merged_layouts list and track which layouts actually changed.
422 mlir::LogicalResult MergeAndGetUpdatedLayouts(
423 const llvm::DenseSet<mlir::Value>& is_locked,
424 llvm::DenseSet<mlir::Value>& is_updated,
425 llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
426 llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
427 consumer_requests,
428 llvm::DenseMap<mlir::Value, Layout>& merged_layouts) {
429 llvm::DenseSet<mlir::Value> updated_merge;
430 for (auto& value : is_updated) {
431 auto& producer_layout = producer_request[value];
432 if (is_locked.find(value) != is_locked.end()) {
433 // Locked values must have a producer request. If the merged_layout is
434 // not already set, then this is the first pass, so we set it and mark
435 // the entry as updated.
436 if (merged_layouts.find(value) == merged_layouts.end()) {
437 if (!producer_layout)
438 return value.getDefiningOp()->emitError() << "missing locked layout";
439 merged_layouts[value] = producer_layout.value();
440 updated_merge.insert(value);
441 }
442 continue;
443 }
444 auto merged = MergeLayouts(producer_layout, consumer_requests[value]);
445 if (!merged.ok())
446 return value.getDefiningOp()->emitOpError()
447 << merged.status().error_message();
448
449 auto current_layout = merged_layouts.find(value);
450 if (current_layout == merged_layouts.end() ||
451 current_layout->second != merged.ValueOrDie()) {
452 updated_merge.insert(value);
453 merged_layouts[value] = merged.ValueOrDie();
454 }
455 }
456
457 is_updated = updated_merge;
458 return mlir::success();
459 }
460
461 // Finds the most sharded merged layout given `layouts`.
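// A rough sketch of the rule implemented below: a mesh dim survives only if it
// is requested for exactly one tensor dimension, a tensor dimension is sharded
// only if exactly one surviving mesh dim is requested for it, and every other
// dimension is left unsharded.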
462 mlir::LogicalResult GetMostShardedLayout(llvm::ArrayRef<Layout> layouts,
463 mlir::Location location,
464 absl::optional<Layout>* out) {
465 // If there are no layouts to merge, leave the output empty.
466 if (layouts.empty()) return mlir::success();
467
468 absl::optional<Layout> layout;
469 std::map<std::string, std::set<int>> layout_map;
470 for (const Layout& layout : layouts) {
471 for (int i = 0; i < layout.rank(); ++i) {
472 const std::string& mesh_dim = layout.dim(i).sharding_spec();
473 if (mesh_dim == Layout::kUnshardedDim) continue;
474
475 layout_map[mesh_dim].insert(i);
476 }
477 }
478
479 for (auto& it : layout_map)
480 if (it.second.size() > 1) it.second.clear();
481
482 std::map<int, std::set<std::string>> dim_to_layout_map;
483 for (const auto& it : layout_map) {
484 assert(it.second.size() <= 1);
485 if (it.second.empty()) continue;
486
487 const int tensor_dim_index = *it.second.begin();
488 dim_to_layout_map[tensor_dim_index].insert(it.first);
489 }
490
491 for (auto& it : dim_to_layout_map)
492 if (it.second.size() > 1) it.second.clear();
493
494 std::vector<std::string> merged_spec;
495 assert(!layouts.empty());
496 for (int i = 0; i < layouts[0].rank(); ++i) {
497 const auto it = dim_to_layout_map.find(i);
498 if (it != dim_to_layout_map.end() && !it->second.empty()) {
499 assert(it->second.size() == 1);
500 merged_spec.emplace_back(*it->second.begin());
501 } else {
502 merged_spec.emplace_back(Layout::kUnshardedDim);
503 }
504 }
505 const auto new_layout = Layout::GetLayout(merged_spec, layouts[0].mesh());
506 if (!new_layout.ok()) {
507 return mlir::emitError(
508 location, llvm::formatv("error in layout propagation while merging "
509 "producer layouts. {0}",
510 new_layout.status().error_message()));
511 }
512 out->emplace(*new_layout);
513 return mlir::success();
514 }
515
516 // Merge layouts of mlir::Value from multiple producers into a single final
517 // layout. A mlir::Value can have multiple producers if the value is from a
518 // tf.If/tf.IfRegion op. Given multiple producer layouts of the same
519 // mlir::Value, the merging logic is as follows:
520 // 1) If a dimension can be sharded, shard the dimension as much as possible.
521 // 2) If mesh dimension is already used or two same mesh dimensions are used
522 // in different dimensions, then leave the dimension as replicated.
523 //
524 // For example:
525 // ("x", replicated) , (replicated, "y") will have ("x", "y") merged layout.
526 // ("x", replicated) , (replicated, "x") will have (replicated, replicated)
527 // merged layout.
528 mlir::LogicalResult MergeProducerLayouts(
529 const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
530 const std::vector<mlir::Value>& producer_values, mlir::Location location,
531 absl::optional<Layout>* layout_out) {
532 // If there is a single producer for mlir::Value, then return the layout
533 // from the producer.
534 absl::optional<Layout> layout;
535 if (producer_values.size() == 1) {
536 const auto it = merged_layouts.find(producer_values[0]);
537 if (it != merged_layouts.end()) *layout_out = it->second;
538 return mlir::success();
539 }
540
541 // For the case with multiple producers, merge the layouts.
542 llvm::SmallVector<Layout, 4> candidate_layouts;
543 candidate_layouts.reserve(producer_values.size());
544 for (mlir::Value value : producer_values) {
545 auto it = merged_layouts.find(value);
546 if (it == merged_layouts.end()) continue;
547 candidate_layouts.emplace_back(it->second);
548 }
549
550 if (mlir::failed(GetMostShardedLayout(candidate_layouts, location, &layout)))
551 return mlir::failure();
552
553 if (layout) *layout_out = *layout;
554 return mlir::success();
555 }
556
557 // For an op, calls the corresponding ComputeLayouts function with the data from
558 // the merged_layouts map. Records the result in the producer_request and
559 // consumer_requests maps and notes if any layouts have changed.
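// At a high level this:
//   1) Collects the currently merged layouts for the op's operands (merging
//      layouts when an operand has multiple producers) and results.
//   2) Calls ComputeLayoutForward/ComputeLayoutBackward on the op's SPMD
//      expander with those layouts.
//   3) Writes the returned layouts back into consumer_requests and
//      producer_request, inserting any mlir::Value whose request changed into
//      is_updated.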
560 mlir::LogicalResult UpdateLayoutsForOp(
561 mlir::Operation* op,
562 const llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
563 const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
564 llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
565 llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
566 consumer_requests,
567 llvm::DenseSet<mlir::Value>& is_updated) {
568 auto* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
569 if (expander == nullptr)
570 return op->emitOpError() << "does not implement layout propagation";
571
572 // Get input and output layouts for this op from the merged_layouts map.
573 llvm::DenseMap<int, Layout> input_layouts(op->getNumOperands());
574 llvm::DenseMap<int, Layout> output_layouts(op->getNumResults());
575
576 for (int i = 0; i < op->getNumOperands(); ++i) {
577 // For inputs, we need to find the producer's mlir::Value that eventually
578 // feeds into this op. This is in the producers map.
579 // Merge the different layouts when there are multiple producer values.
580 auto producer_values = producers.find(&(op->getOpOperand(i)));
581 if (producer_values == producers.end())
582 return op->emitError() << "Unable to find producer for operand " << i;
583
584 absl::optional<Layout> layout;
585 if (mlir::failed(MergeProducerLayouts(merged_layouts,
586 producer_values->getSecond(),
587 op->getLoc(), &layout)))
588 return mlir::failure();
589
590 if (layout) input_layouts[i] = *layout;
591 }
592
593 for (int i = 0; i < op->getNumResults(); ++i) {
594 auto layout = merged_layouts.find(op->getOpResult(i));
595 if (layout != merged_layouts.end()) output_layouts[i] = layout->second;
596 }
597
598 auto forward_result =
599 expander->ComputeLayoutForward(op, input_layouts, output_layouts);
600 if (!forward_result.ok()) {
601 return op->emitOpError() << "ComputeLayoutForward error: "
602 << forward_result.status().error_message();
603 }
604 const auto new_output_layouts = *forward_result;
605 auto backward_result =
606 expander->ComputeLayoutBackward(op, input_layouts, output_layouts);
607 if (!backward_result.ok()) {
608 return op->emitOpError() << "ComputeLayoutBackward error: "
609 << backward_result.status().error_message();
610 }
611 const auto new_input_layouts = *backward_result;
612
613 // Update the consumer layouts for this op.
614 for (int i = 0; i < op->getNumOperands(); ++i) {
615 mlir::OpOperand* operand = &(op->getOpOperand(i));
616 // No need to check that this exists, we already did it above.
617 const auto& producer_values = producers.find(operand);
618 const auto input_layout = new_input_layouts.find(i);
619
620 for (mlir::Value value : producer_values->getSecond()) {
621 auto& consumer_request = consumer_requests[value];
622 const auto consumer_request_from_op_operand =
623 consumer_request.find(operand);
624
625 // Update the consumer_request for this OpOperand: we respect what compute
626 // layout returns and erase a previously requested layout if no layout is
627 // returned.
628 // TODO(hongjunchoi, bfontain): Consider the case when op output type is
629 // resource type with subtype information.
630 if (input_layout != new_input_layouts.end() &&
631 (consumer_request_from_op_operand == consumer_request.end() ||
632 input_layout->second != consumer_request_from_op_operand->second)) {
633 // RestoreV2 op most likely would have unknown rank upon restoring, and
634 // we relax unknown rank check for the inputs that are produced from
635 // there.
636 const bool exempt_restore_unknown_rank =
637 ValueRank(value) == -1 && value.getDefiningOp() &&
638 llvm::isa<mlir::TF::RestoreV2Op>(value.getDefiningOp());
639 if (!exempt_restore_unknown_rank &&
640 input_layout->second.rank() != ValueRank(value))
641 return op->emitOpError()
642 << "Rank for input " << i << " layout is "
643 << input_layout->second.rank() << " but actual rank is "
644 << ValueRank(value);
645
646 // If there was a layout returned and either no previous request or the
647 // request changed, insert and mark as updated.
648 consumer_request[operand] = input_layout->second;
649 is_updated.insert(value);
650 } else if (input_layout == new_input_layouts.end() &&
651 consumer_request_from_op_operand != consumer_request.end()) {
652 // If no layout was returned and there is previous request, erase the
653 // old consumer request.
654 consumer_request.erase(operand);
655 is_updated.insert(value);
656 }
657 }
658 }
659
660 // Update the producer request for this op.
661 // If the output is different from what is in the request list, update the
662 // request and mark the mlir::Value as having an updated Layout request.
663 for (int i = 0; i < op->getNumResults(); ++i) {
664 const auto output_layout = new_output_layouts.find(i);
665 if (output_layout == new_output_layouts.end()) continue;
666 const auto& result = op->getOpResult(i);
667 if (producer_request[result] != output_layout->second) {
668 if (output_layout->second.rank() != ValueRank(result))
669 return op->emitOpError() << "Rank for output " << i << " layout is "
670 << output_layout->second.rank()
671 << " but actual rank is " << ValueRank(result);
672 producer_request[result] = output_layout->second;
673 is_updated.insert(result);
674 }
675 }
676 return mlir::success();
677 }
678
679 mlir::LogicalResult InsertDTensorLayoutOps(
680 mlir::OpBuilder& builder,
681 const llvm::DenseMap<mlir::Value, Layout>& merged_layouts) {
682 for (const auto& merged_layout : merged_layouts) {
683 // merged_layout is a pair of mlir::Value and Layout.
684 // If there is only one user of the Value and that user is a DTensorLayout
685 // op, then we can skip creating the op as the layout is already there. Note
686 // that we specifically do not allow updating a layout in an already present
687 // DTensorLayout op as we have considered them to be 'locked' throughout
688 // the algorithm.
689 const auto& users = merged_layout.first.getUsers();
690 int num_users = std::distance(users.begin(), users.end());
691 if (num_users == 1 && mlir::isa<mlir::TF::DTensorLayout>(*users.begin()))
692 continue;
693 builder.setInsertionPointAfterValue(merged_layout.first);
694 // Handle resource and variant types, where the real shape is embedded in
695 // the type's subtype elements.
696 mlir::Type value_type = GetSubtypeOrSelf(merged_layout.first);
697
698 if (auto type = value_type.dyn_cast<mlir::TensorType>()) {
699 auto layout_op = builder.create<mlir::TF::DTensorLayout>(
700 merged_layout.first.getLoc(), merged_layout.first,
701 mlir::dtensor::LayoutAttr::get(builder.getContext(),
702 merged_layout.second),
703 mlir::TF::ShapeAttr::get(builder.getContext(), type));
704 llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
705 merged_layout.first.replaceAllUsesExcept(layout_op.output(), exception);
706 } else {
707 mlir::emitError(merged_layout.first.getLoc())
708 << "value type is not TensorType as expected.";
709 }
710 }
711
712 return mlir::success();
713 }
714
715 void GetOperationsNeedingUpdate(
716 const llvm::DenseSet<mlir::Value>& is_updated,
717 const llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
718 llvm::DenseSet<mlir::Operation*>& operations) {
719 for (auto& value : is_updated) {
720 auto uses = consumers.find(value);
721 // Some values have no consumers (e.g. outputs of the main function).
722 if (uses != consumers.end())
723 for (auto* use : uses->second)
724 if (!mlir::isa<mlir::TF::CopyToMeshOp>(use->getOwner()))
725 operations.insert(use->getOwner());
726 // If this is an OpResult, also add the op that produces it.
727 if (value.isa<mlir::OpResult>() &&
728 !mlir::isa<mlir::TF::CopyToMeshOp>(value.getDefiningOp()))
729 operations.insert(value.getDefiningOp());
730 }
731 }
732
733 namespace {
734
735 // Custom printing class which prints out layouts and ignores DTensorLayout
736 // ops and also non-registered attributes.
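// For ops without a custom assembly printer, printGenericOp emits lines
// roughly of the form (an illustrative sketch; the values, types and
// attributes are hypothetical):
//   %2 "x,*" tensor<8x2xf32> = "tf.MatMul"(%0 "x,*" tensor<8x4xf32>,
//       %1 "*,*" tensor<4x2xf32>) {transpose_a = false} [MatMul]
// where the quoted string after each value is its merged layout and [MatMul]
// is the TF node name taken from the op's location.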
737 class LayoutPrinter : public mlir::OpAsmPrinter {
738 public:
739 explicit LayoutPrinter(
740 llvm::raw_ostream& os,
741 const llvm::DenseMap<mlir::Value, Layout>& merged_layouts)
742 : indent_level_(0),
743 os_(os),
744 current_location_(0),
745 next_argument_(0),
746 merged_layouts_(merged_layouts) {}
747
748 llvm::raw_ostream& getStream() const override { return os_; }
749
750 void printRegionArgument(mlir::BlockArgument arg,
751 llvm::ArrayRef<mlir::NamedAttribute> argAttrs,
752 bool omitType) override {
753 printOperand(arg);
754 if (!omitType) {
755 os_ << ": ";
756 printType(arg.getType());
757 }
758 printOptionalAttrDict(argAttrs, llvm::None);
759 }
760
761 void printOperand(mlir::Value value) override { printOperand(value, os_); }
762
763 /// Print a newline and indent the printer to the start of the current
764 /// operation.
765 void printNewline() override {
766 os_ << "\n";
767 os_.indent(indent_level_);
768 }
769
770 // Note that we ignore the parameters printEntryBlockArgs and
771 // printBlockTerminators for simplicity.
772 void printRegion(mlir::Region& blocks, bool printEntryBlockArgs,
773 bool printBlockTerminators,
774 bool printEmptyBlock = false) override {
775 os_ << " {\n";
776 for (auto& b : blocks.getBlocks()) print(b);
777 os_.indent(indent_level_) << "}";
778 }
779
780 void print(mlir::Block& block) {
781 // Each nested block level increases the indent.
782 os_.indent(indent_level_) << "%bb(";
783 for (int i = 0; i < block.getNumArguments(); ++i) {
784 if (arguments_.find(block.getArgument(i)) == arguments_.end())
785 arguments_[block.getArgument(i)] = next_argument_++;
786 if (i > 0) os_ << ", ";
787 os_ << "%arg" << arguments_[block.getArgument(i)];
788 }
789 os_ << "):\n";
790 indent_level_ += 2;
791 for (auto& op : block.getOperations()) print(op);
792 indent_level_ -= 2;
793 }
794
795 // Prints the TF node name from `loc`.
796 void printLoc(mlir::Location loc) {
797 os_ << " [" << mlir::GetNameFromLoc(loc) << "]";
798 }
799
800 void print(mlir::Operation& op) {
801 // Don't print tf.DTensorLayout ops.
802 if (mlir::isa<mlir::TF::DTensorLayout>(op)) return;
803
804 // Don't print functions with empty bodies.
805 if (auto func_op = mlir::dyn_cast<mlir::func::FuncOp>(op))
806 if (func_op.empty()) return;
807
808 // Each operation is on its own line, so we start by indenting the
809 // line.
810 os_.indent(indent_level_);
811
812 // Record a unique identifier for the op (this will be used for printing
813 // op results and operands).
814 location_[&op] = current_location_++;
815
816 // Print the outputs.
817 for (int i = 0; i < op.getNumResults(); ++i) {
818 if (i > 0) os_ << ", ";
819 printOperand(op.getOpResult(i), os_);
820 }
821
822 if (op.getNumResults() > 0) os_ << " = ";
823
824 // Some ops have a special printing method, call this if it exists.
825 if (auto opInfo = op.getRegisteredInfo()) {
826 opInfo->printAssembly(&op, *this, /*defaultDialect=*/"");
827 printLoc(op.getLoc());
828 os_ << "\n";
829 return;
830 }
831
832 // Otherwise we do a generic printing.
833 printGenericOp(&op, true);
834 printLoc(op.getLoc());
835
836 os_ << "\n";
837 }
838
839 // Print an operand; this could be either an OpResult or a BlockArgument.
840 // We also print the layout if it exists and the type.
841 void printOperand(mlir::Value value, llvm::raw_ostream& os) override {
842 if (auto result = value.dyn_cast<mlir::OpResult>()) {
843 // If DTensorLayout ops are already in the module, we need to skip them
844 // since we aren't printing them out.
845 if (mlir::isa<mlir::TF::DTensorLayout>(result.getDefiningOp())) {
846 printOperand(result.getDefiningOp()->getOperand(0));
847 return;
848 }
849
850 // OpResults are of the format %op_number:result_number. We elide the
851 // result number if there is only one result (the case for most ops).
852 os << "%" << location_[result.getDefiningOp()];
853 if (result.getDefiningOp()->getNumResults() > 1)
854 os << ":" << result.getResultNumber();
855 } else if (auto argument = value.dyn_cast<mlir::BlockArgument>()) {
856 if (arguments_.find(argument) == arguments_.end())
857 arguments_[argument] = next_argument_++;
858 os << "%arg" << arguments_[argument];
859 }
860 auto layout = merged_layouts_.find(value);
861 if (layout != merged_layouts_.end()) {
862 os << " \"";
863 printLayout(layout->second, os);
864 os << "\"";
865 }
866 os << " ";
867 printType(value.getType());
868 }
869
870 void printLayout(const Layout& layout, llvm::raw_ostream& os) {
871 // Layouts are printed with * for an unsharded dim and the mesh dim for a
872 // sharded dim. This keeps the layout compact.
873 for (int i = 0; i < layout.rank(); ++i) {
874 if (i > 0) os << ",";
875 if (Layout::IsUnshardedDimension(layout.sharding_spec(i)))
876 os << "*";
877 else
878 os << layout.sharding_spec(i);
879 }
880 }
881
882 // A generic op consists of a name, and any of the following:
883 // * arguments,
884 // * attributes
885 // * regions
886 // These are printed out in that order.
887 void printGenericOp(mlir::Operation* op, bool printOpName) override {
888 if (printOpName) os_ << "\"" << op->getName().getStringRef() << "\"";
889 os_ << "(";
890 for (int i = 0; i < op->getNumOperands(); ++i) {
891 if (i > 0) os_ << ", ";
892 printOperand(op->getOperand(i), os_);
893 }
894 os_ << ")";
895
896 if (!op->getAttrs().empty()) {
897 std::vector<mlir::NamedAttribute> filtered;
898 for (auto attr : op->getAttrs())
899 if (*attr.getName().str().begin() != '_' &&
900 attr.getName().str() != "device")
901 filtered.emplace_back(attr);
902 if (!filtered.empty()) {
903 os_ << " {";
904 for (int i = 0; i < filtered.size(); ++i) {
905 if (i > 0) os_ << ", ";
906 printNamedAttribute(filtered[i]);
907 }
908 os_ << "}";
909 }
910 }
911
912 if (op->getNumRegions() > 0) {
913 os_ << " (";
914 for (auto& region : op->getRegions()) printRegion(region, false, false);
915 os_ << ")";
916 }
917 };
918
919 void printSymbolName(llvm::StringRef symbolRef) override {
920 os_ << symbolRef;
921 };
922
923 void printNamedAttribute(mlir::NamedAttribute attr) {
924 os_ << attr.getName().strref() << " = ";
925 printAttribute(attr.getValue());
926 }
927
928 void printAttribute(mlir::Attribute attr) override { attr.print(os_); }
929
930 void printType(mlir::Type type) override { type.print(os_); }
931
932 // The following functions are part of the printing interface but aren't
933 // needed for the compact printing form for Layout printing.
934 void printAttributeWithoutType(mlir::Attribute attr) override{};
935 void printSuccessor(mlir::Block* successor) override{};
936 void printSuccessorAndUseList(mlir::Block* successor,
937 mlir::ValueRange succOperands) override{};
938 void printOptionalAttrDict(
939 llvm::ArrayRef<mlir::NamedAttribute> attrs,
940 llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{};
941 void printOptionalAttrDictWithKeyword(
942 llvm::ArrayRef<mlir::NamedAttribute> attrs,
943 llvm::ArrayRef<llvm::StringRef> elidedAttrs) override{};
944
945 void shadowRegionArgs(mlir::Region& region,
946 mlir::ValueRange namesToUse) override{};
947 void printAffineMapOfSSAIds(mlir::AffineMapAttr mapAttr,
948 mlir::ValueRange operands) override{};
949
950 void printAffineExprOfSSAIds(mlir::AffineExpr expr,
951 mlir::ValueRange dimOperands,
952 mlir::ValueRange symOperands) override{};
953
954 private:
955 int indent_level_;
956 llvm::raw_ostream& os_;
957 llvm::DenseMap<mlir::Operation*, int> location_;
958 int current_location_;
959 llvm::DenseMap<mlir::BlockArgument, int> arguments_;
960 int next_argument_;
961 const llvm::DenseMap<mlir::Value, Layout>& merged_layouts_;
962 };
963
964 // Log the current set of layouts to a file marked by the hash of the input
965 // module and the stage.
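// The resulting dump path looks roughly like
//   <dump_dir>/layout_propagation_v2_module_<module_hash>_stage_<stage>_<unique>.mlir
// where <unique> is filled in by CreateUniqueFileName.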
966 void LogLayoutsAndOps(const int stage, const uint64_t module_hash,
967 const llvm::DenseMap<mlir::Value, Layout>& merged_layouts,
968 mlir::ModuleOp& module) {
969 if (module->hasAttr(kDoNotLog) || ((ClientId() != 0) && !LogOnAllTasks()))
970 return;
971
972 std::string prefix = tensorflow::GetDumpDirFromEnvVar();
973 if (prefix.empty()) return;
974
975 auto* env = tensorflow::Env::Default();
976 auto status = env->RecursivelyCreateDir(prefix);
977 if (!status.ok()) {
978 LOG(WARNING) << "cannot create directory '" + prefix +
979 "': " + status.error_message();
980 return;
981 }
982
983 absl::StrAppend(&prefix, "/layout_propagation_v2_module_", module_hash,
984 "_stage_", stage, "_");
985 if (!tensorflow::Env::Default()->CreateUniqueFileName(&prefix, ".mlir")) {
986 LOG(WARNING) << "cannot create unique filename, won't dump MLIR module.";
987 return;
988 }
989
990 std::unique_ptr<WritableFile> file_writer;
991 status = env->NewWritableFile(prefix, &file_writer);
992 if (!status.ok()) {
993 LOG(WARNING) << "cannot open file '" + prefix +
994 "': " + status.error_message();
995 return;
996 }
997
998 // Print the module to a string before writing to the file.
999 std::string txt_module;
1000 {
1001 llvm::raw_string_ostream os(txt_module);
1002 LayoutPrinter printer(os, merged_layouts);
1003 module.print(printer);
1004 }
1005
1006 status = file_writer->Append(txt_module);
1007 if (!status.ok()) {
1008 LOG(WARNING) << "error writing to file '" + prefix +
1009 "': " + status.error_message();
1010 return;
1011 }
1012 (void)file_writer->Close();
1013 LOG(INFO) << "Dumped MLIR module to " << prefix;
1014 }
1015
1016 // Canonicalizer and DCE transformation passes may remove ops in the graph and
1017 // result in multiple consecutive DTensorLayout ops. Detect all such cases and
1018 // replace unnecessary DTensorLayout ops with Identity ops.
1019 mlir::LogicalResult ReplaceAuxiliaryDTensorLayoutOpsWithIdentity(
1020 mlir::ModuleOp module) {
1021 llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops;
1022 module.walk([&](mlir::TF::DTensorLayout op) { layout_ops.emplace_back(op); });
1023
1024 for (auto layout_op : llvm::reverse(layout_ops)) {
1025 auto input_op = layout_op.input().getDefiningOp();
1026 if (auto input_layout_op =
1027 llvm::dyn_cast_or_null<mlir::TF::DTensorLayout>(input_op)) {
1028 // Check that layout of input DTensorLayout op is equivalent to
1029 // the layout of its connected DTensorLayout op.
1030 if (layout_op.layout() != input_layout_op.layout())
1031 return layout_op.emitOpError(
1032 "Found inconsistent layout. This should never happen.");
1033
1034 // Replace DTensorLayout op with identity op.
1035 mlir::OpBuilder builder(layout_op);
1036 auto identity = builder.create<mlir::TF::IdentityOp>(
1037 layout_op->getLoc(), layout_op.getType(), layout_op.input());
1038 layout_op.output().replaceAllUsesWith(identity.output());
1039 layout_op.erase();
1040 }
1041 }
1042
1043 return mlir::success();
1044 }
1045
1046 // Inserts/changes DTensorLayout op after IfRegion op and results of then/else
1047 // branches to ensure that the return values of IfRegion ops are consistent.
1048 // After layout propagation, layouts of return value of tf.IfRegion op, and
1049 // layouts of terminators of then/else branches of IfRegion op may be different.
1050 // In that case, the layouts of the return values must be merged to a single
1051 // layout, as the return values of the IfRegion op and the results of the
1052 // then/else branches are semantically equivalent.
1053 mlir::LogicalResult InsertDTensorLayoutForIfRegionOp(
1054 const llvm::SmallVectorImpl<mlir::TF::IfRegionOp>& if_ops,
1055 mlir::MLIRContext* context) {
1056 for (mlir::TF::IfRegionOp if_op : if_ops) {
1057 for (mlir::OpResult if_result : if_op.getResults()) {
1058 const int result_index = if_result.getResultNumber();
1059 mlir::Value then_branch_result = if_op.then_branch()
1060 .front()
1061 .getTerminator()
1062 ->getOpOperand(result_index)
1063 .get();
1064 mlir::Value else_branch_result = if_op.else_branch()
1065 .front()
1066 .getTerminator()
1067 ->getOpOperand(result_index)
1068 .get();
1069
1070 auto if_result_layout =
1071 llvm::dyn_cast<mlir::TF::DTensorLayout>(*if_result.user_begin());
1072 auto then_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(
1073 *then_branch_result.getDefiningOp());
1074 auto else_result_layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(
1075 *else_branch_result.getDefiningOp());
1076 llvm::SmallVector<Layout, 4> layouts{if_result_layout.layout(),
1077 then_result_layout.layout(),
1078 else_result_layout.layout()};
1079 std::set<Layout> layouts_set{layouts.begin(), layouts.end()};
1080 if (layouts_set.size() == 1) continue;
1081
1082 absl::optional<Layout> merged_layout;
1083 if (mlir::failed(
1084 GetMostShardedLayout(layouts, if_op.getLoc(), &merged_layout)))
1085 return mlir::failure();
1086 assert(merged_layout);
1087
1088 if_result_layout->setAttr(
1089 kQualifiedLayoutAttr,
1090 mlir::dtensor::LayoutAttr::get(context, *merged_layout));
1091 then_result_layout->setAttr(
1092 kQualifiedLayoutAttr,
1093 mlir::dtensor::LayoutAttr::get(context, *merged_layout));
1094 else_result_layout->setAttr(
1095 kQualifiedLayoutAttr,
1096 mlir::dtensor::LayoutAttr::get(context, *merged_layout));
1097 }
1098 }
1099 return mlir::success();
1100 }
1101
1102 // Inserts necessary DTensorRelayout ops so that the layouts for while loops
1103 // are correct.
1104 //
1105 // Due to how while loop layout propagation is done, we may need to fix the
1106 // layouts so that the second and beyond step of the loop receive a tensor with
1107 // the correct layout.
1108 // E.g.
1109 // %b = tf.WhileRegion(%a) ({
1110 // %bb0(%arg0): # Cond
1111 // %c = tf.A(%arg0)
1112 // tf.Yield(%c)
1113 // }, {
1114 // %bb0(%arg0): # Body
1115 // %d = tf.B(%arg0)
1116 // tf.Yield(%d)
1117 // }
1118 // }
1119 // %e = tf.C(%b)
1120 //
1121 // Layout propagation treats the loop body as if it were an inlined function and
1122 // does not have a constraint which fixes the layout of %d, as the return
1123 // value, to match the layout of %arg0 (or %a).
1124 //
1125 // Towards this, we:
1126 // 1) Check the layout of %arg0 and see if it matches the layout of input 0
1127 // (%d) of tf.Yield.
1128 // 2) If it doesn't match, we insert a DTensorRelayout op between
1129 // %d and tf.Yield with the correct layout and insert a second
1130 // DTensorRelayout op after the loop body.
1131 //
1132 // NOTE that it is necessary in general to insert both DTensorRelayout ops,
1133 // as opposed to just updating the layout of %d (which would in general be more
1134 // efficient) since %d may still be used by other ops in the loop body.
1135 //
1136 // NOTE: this is not needed for the condition as the output of the condition is
1137 // a scalar and therefore always replicated.
1138 mlir::LogicalResult InsertRelayoutForWhileLoops(
1139 const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops,
1140 mlir::OpBuilder& builder) {
1141 for (mlir::TF::WhileRegionOp op : while_ops) {
1142 // Get the terminator so we can check the output layouts of the loop body.
1143 mlir::Operation* yield_op = op.body().front().getTerminator();
1144 if (!mlir::isa<mlir::TF::YieldOp>(yield_op))
1145 return op->emitOpError() << "body terminator is not a Yield op.";
1146
1147 for (int i = 0; i < op.body().getNumArguments(); ++i) {
1148 // Inputs should have only one use, a DTensorLayout op.
1149 mlir::Value argument = op.body().getArgument(i);
1150 if (!argument.hasOneUse())
1151 return op.emitOpError()
1152 << "body argument " << i << " doesn't have a single use.";
1153 mlir::Operation* input_layout_op = argument.getUses().begin().getUser();
1154 if (!mlir::isa<mlir::TF::DTensorLayout>(input_layout_op))
1155 return op.emitOpError() << "body argument " << i
1156 << " is not consumed by a DTensorLayout op.";
1157 const Layout input_layout =
1158 mlir::cast<mlir::TF::DTensorLayout>(input_layout_op).layout();
1159
1160 // Inputs to Yield should also be a DTensorLayout op.
1161 if (!yield_op->getOperand(i).isa<mlir::OpResult>() ||
1162 !mlir::isa<mlir::TF::DTensorLayout>(
1163 yield_op->getOperand(i).getDefiningOp()))
1164 return yield_op->emitOpError()
1165 << "argument " << i << " to is not a DTensorLayout op.";
1166 mlir::Operation* output_layout_op =
1167 yield_op->getOperand(i).getDefiningOp();
1168 const Layout output_layout =
1169 mlir::cast<mlir::TF::DTensorLayout>(output_layout_op).layout();
1170
1171 // If the layouts are equal we have nothing to do. Note that this catches
1172 // the case that the input and output are a resource, since the layout
1173 // of a resource is fixed.
1174 if (input_layout == output_layout) continue;
1175
1176 // Insert the first Relayout op (in the loop body).
1177 builder.setInsertionPointAfter(output_layout_op);
1178 if (!yield_op->getOperand(i).getType().isa<mlir::TensorType>())
1179 return yield_op->emitOpError()
1180 << "operand " << i << " does not have TensorType";
1181 mlir::TF::ShapeAttr global_shape = mlir::TF::ShapeAttr::get(
1182 builder.getContext(),
1183 yield_op->getOperand(i).getType().cast<mlir::TensorType>());
1184 mlir::TF::RelayoutOp first_relayout =
1185 builder.create<mlir::TF::RelayoutOp>(
1186 op.getLoc(), yield_op->getOperand(i).getType(),
1187 yield_op->getOperand(i), input_layout.ToString());
1188 mlir::TF::DTensorLayout first_layout_op =
1189 builder.create<mlir::TF::DTensorLayout>(
1190 op.getLoc(), first_relayout.output(),
1191 mlir::dtensor::LayoutAttr::get(builder.getContext(),
1192 input_layout),
1193 global_shape);
1194 yield_op->setOperand(i, first_layout_op.output());
1195
1196 // Insert the second relayout op after the loop itself.
1197 builder.setInsertionPointAfter(op);
1198 mlir::TF::DTensorLayout second_layout_op =
1199 builder.create<mlir::TF::DTensorLayout>(
1200 op.getLoc(), op->getResult(i),
1201 mlir::dtensor::LayoutAttr::get(builder.getContext(),
1202 input_layout),
1203 global_shape);
1204 mlir::TF::RelayoutOp second_relayout =
1205 builder.create<mlir::TF::RelayoutOp>(
1206 op.getLoc(), second_layout_op.output().getType(),
1207 second_layout_op.output(), output_layout.ToString());
1208 op->getResult(i).replaceAllUsesExcept(
1209 second_relayout.output(), llvm::SmallPtrSet<mlir::Operation*, 1>{
1210 second_layout_op.getOperation()});
1211 }
1212 }
1213 return mlir::success();
1214 }
1215
1216 // For all constants with multiple usages, clone the constants so that each
1217 // constant operation has at most 1 usage.
1218 void DuplicateConstants(mlir::ModuleOp module) {
1219 llvm::SmallVector<mlir::TF::ConstOp, 4> const_ops;
1220 module.walk(
1221 [&](mlir::TF::ConstOp const_op) { const_ops.emplace_back(const_op); });
1222
1223 for (mlir::TF::ConstOp const_op : const_ops) {
1224 mlir::OpBuilder builder(const_op);
1225 auto uses = const_op->getUses();
1226 if (uses.empty()) return;
1227
1228 llvm::SmallDenseMap<mlir::Operation*, mlir::OpOperand*> const_use_map;
1229 mlir::OpOperand& first_use = *uses.begin();
1230 for (mlir::OpOperand& use : uses) {
1231 if (&use == &first_use) continue;
1232
1233 mlir::Operation* new_const = builder.clone(*const_op);
1234 const_use_map.try_emplace(new_const, &use);
1235 }
1236
1237 for (const auto& it : const_use_map) it.second->set(it.first->getResult(0));
1238 }
1239 }
1240
1241 // Finds the root value(s) of "current_value" within the cycle, and puts them
1242 // into "roots".
1243 void FindRoot(
1244 const llvm::DenseSet<mlir::Value>& is_updated,
1245 const mlir::Value& current_value,
1246 llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
1247 llvm::DenseSet<mlir::Value>* roots) {
1248 // Standard BFS to find root values of current_value.
1249 std::deque<mlir::Value> to_process;
1250 to_process.push_back(current_value);
1251
1252 llvm::DenseSet<mlir::Value> visited;
1253 visited.insert(current_value);
1254
1255 while (!to_process.empty()) {
1256 int level_size = to_process.size();
1257 for (int UNUSED = 0; UNUSED < level_size; ++UNUSED) {
1258 mlir::Value cur_val = to_process.front();
1259 to_process.pop_front();
1260
1261 // Terminating condition: if there is no defining op then this is a root.
1262 mlir::Operation* defining_op = cur_val.getDefiningOp();
1263 if (defining_op == nullptr) {
1264 roots->insert(current_value);
1265 continue;
1266 }
1267
1268 // Expand out from 'cur_val' one step closer to the roots. If there is
1269 // no value one step closer to a root, then this is a root.
1270 bool is_root = true;
1271 for (int i = 0; i < defining_op->getNumOperands(); ++i) {
1272 mlir::Value operand = defining_op->getOperand(i);
1273 if (operand != cur_val && is_updated.contains(operand)) {
1274 is_root = false;
1275
1276 if (!visited.contains(operand)) {
1277 visited.insert(operand);
1278 to_process.push_back(operand);
1279 }
1280 }
1281 }
1282
1283 if (is_root) roots->insert(cur_val);
1284 }
1285 }
1286 }

// Finds the root value(s) of the values whose layouts cycle back and forth
// during layout propagation, and prints the closest TF op that consumes each
// root value. This helps users and developers identify which layouts need to
// change to prevent layout propagation from looping forever.
void FindRootsAndEmitError(
    mlir::ModuleOp& module,
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers,
    const llvm::DenseSet<mlir::Value>& is_updated) {
  llvm::DenseSet<mlir::Value> roots;
  for (auto& value : is_updated) {
    FindRoot(is_updated, value, producers, &roots);
  }
  module.emitOpError()
      << "Maximum number of layout propagation steps reached. Unable to "
         "converge to a fixed layout. Please rerun with a higher limit in the "
         "DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS environment variable.";
  for (auto& root : roots) {
    for (mlir::OpOperand& operand : root.getUses()) {
      llvm::DenseMap<llvm::StringRef, mlir::Operation*> func_to_caller;
      llvm::SmallVector<mlir::Value, 4> skipped_values;

      // For each root value that may need a different layout, find the
      // closest TF op that consumes it and print it.
      llvm::SmallVector<mlir::OpOperand*, 4> consuming_operands =
          TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values);

      for (mlir::OpOperand* new_operand : consuming_operands) {
        mlir::Operation* operation = new_operand->getOwner();
        mlir::Location loc = operation->getLoc();
        operation->emitOpError() << '\n'
                                 << "The following op consumes tensors that "
                                    "may need a different layout. ["
                                 << mlir::GetNameFromLoc(loc) << "]" << '\n';
      }
    }
  }
}
}  // namespace

// Runs an iteration of layout propagation, where we merge producer and
// consumer requests and then recompute recommended layouts on all operations
// that are connected to an updated layout.
Status RunOneIteration(
    llvm::DenseSet<mlir::Value>& is_locked,
    llvm::DenseSet<mlir::Value>& is_updated,
    llvm::DenseMap<mlir::Value, absl::optional<Layout>>& producer_request,
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>&
        consumer_requests,
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>>& producers,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers,
    llvm::DenseMap<mlir::Value, Layout>& merged_layouts, mlir::ModuleOp& module,
    const uint64_t module_hash, int* stage) {
  if (is_updated.empty()) return Status::OK();
  // Merge any possibly updated layouts.
  if (mlir::failed(
          MergeAndGetUpdatedLayouts(is_locked, is_updated, producer_request,
                                    consumer_requests, merged_layouts)))
    return errors::Internal(
        "MergeAndGetUpdatedLayouts failed to merge layouts.");

  // Compile a list of operations with updated inputs or outputs.
  llvm::DenseSet<mlir::Operation*> operations_needing_update;
  GetOperationsNeedingUpdate(is_updated, consumers, operations_needing_update);
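  // Clear the set of updated values; UpdateLayoutsForOp below repopulates it
  // with the values whose layouts change during this iteration.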
  is_updated.clear();

  if (VLOG_IS_ON(2)) {
    LogLayoutsAndOps(*stage, module_hash, merged_layouts, module);
  }

  for (auto* op : operations_needing_update) {
    if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts,
                                        producer_request, consumer_requests,
                                        is_updated)))
      return errors::Internal("UpdateLayoutsForOp failed to update layouts.");
  }
  ++(*stage);
  return Status::OK();
}

// Compares every value's layout in `merged_a` with the one in `merged_b`, and
// stores the values that differ in `changed`.
Status CompareMergedLayouts(const llvm::DenseMap<mlir::Value, Layout>& merged_a,
                            const llvm::DenseMap<mlir::Value, Layout>& merged_b,
                            llvm::DenseSet<mlir::Value>& changed) {
  if (merged_a.size() != merged_b.size())
    return errors::Internal(
        "Both merged_layouts did not have the same number of set layouts.");
  for (const auto& value_and_layout : merged_a) {
    const mlir::Value value = value_and_layout.getFirst();
    const Layout& layout = value_and_layout.getSecond();
    auto value_and_layout_in_b = merged_b.find(value);
    if (value_and_layout_in_b == merged_b.end())
      return errors::Internal(
          "Comparing merged_layouts that contain different mlir::Value's.");
    if (value_and_layout_in_b->second != layout) {
      changed.insert(value);
    }
  }
  return Status::OK();
}

// MLIR pass that propagates layouts for all ops in the module.
struct DLayoutPropagationPassV2
    : public DTensorLayoutPropagationV2Base<DLayoutPropagationPassV2> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);

    auto module = getOperation();

    if (mlir::failed(ReplaceAuxiliaryDTensorLayoutOpsWithIdentity(module)))
      return signalPassFailure();

    // To ensure that constant operations with multiple uses and different
    // consumer layout requests do not lead to replicated constant tensors,
    // duplicate all constants so that each has at most one use. After SPMD
    // expansion, these duplicated constants will be merged back during the
    // SCCP pass.
    DuplicateConstants(module);

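    // Layout propagation runs over the module's `main` function; if there is
    // no main function there is nothing to propagate.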
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    mlir::Dialect* tf_dialect =
        context.getLoadedDialect<mlir::TF::TensorFlowDialect>();

    // This maps from OpResults to a list of OpOperands that consume them.
    // Note that this will pass over/through (Stateful)PartitionedCall and
    // other control flow, directly connecting producing ops to their consumers
    // in the function. I.e. it presents a flattened/inlined view of the flow
    // of data.
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>> consumers;
    // Maintain a reverse mapping.
    llvm::DenseMap<mlir::OpOperand*, std::vector<mlir::Value>> producers;
    // For each mlir::Value this is what the producer would like the layout to
    // be.
    llvm::DenseMap<mlir::Value, absl::optional<Layout>> producer_request;
    // For each mlir::Value this is what the consumers would like the layout to
    // be. Note the map is 'parallel' to the consumers map above.
    llvm::DenseMap<mlir::Value, mlir::DenseMap<mlir::OpOperand*, Layout>>
        consumer_requests;
    // Tracks whether the layout was updated since the last cycle.
    llvm::DenseSet<mlir::Value> is_updated;
    // Tracks whether the layout is locked. In this case we don't pass consumer
    // layouts to MergeLayouts. Used for input layouts and user-defined
    // layouts.
    llvm::DenseSet<mlir::Value> is_locked;

    // Create the consumers and producers maps.
    if (mlir::failed(
            PopulateConsumersFromModule(&module, tf_dialect, consumers)))
      return signalPassFailure();

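    // Invert the consumers map to build the producers map: for every consuming
    // OpOperand, record the value(s) that may flow into it.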
    for (auto& consumer : consumers) {
      for (auto* operand : consumer.second) {
        if (producers.find(operand) == producers.end()) {
          producers[operand] = std::vector<mlir::Value>{consumer.first};
        } else {
          producers[operand].emplace_back(consumer.first);
        }
      }
    }

    // Set up the initial conditions for the layout propagation algorithm.
    if (mlir::failed(InsertInitialLayouts(
            module, main_func, consumers, producers, consumer_requests,
            producer_request, is_updated, is_locked)))
      return signalPassFailure();

    const auto module_hash = OpHash(module);
    int stage = 0;

    llvm::DenseMap<mlir::Value, Layout> merged_layouts;
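    // Each stage runs propagation until it either converges or hits the step
    // limit. If it did not converge, the oscillating values are detected,
    // everything else is frozen, and the stale state for the oscillating
    // values is cleared before the next stage retries.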
    while (!is_updated.empty() && stage < kLayoutPropagationMaxStages) {
      ++stage;
      int steps = 0;
      // Step 1. Run the layout propagation v2 until convergence or max steps.
      while (!is_updated.empty() && steps < LayoutPropagationMaxSteps()) {
        Status status = RunOneIteration(
            is_locked, is_updated, producer_request, consumer_requests,
            producers, consumers, merged_layouts, module, module_hash, &steps);
        if (!status.ok()) {
          module.emitOpError() << "Failure running iteration.";
          return signalPassFailure();
        }
      }
      if (VLOG_IS_ON(2)) {
        LOG(INFO) << "Failed to converge in stage " << stage;
      }
      // Step 2. If we are here, then we failed to converge, and likely there
      // is an oscillation of layouts. Detect all the edges that are changing
      // layouts.
      llvm::DenseMap<mlir::Value, Layout> merged_layouts_at_max_steps =
          merged_layouts;
      llvm::DenseSet<mlir::Value> changed;
      int previous_change_size = -1;

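      // Keep iterating until the set of values whose layout differs from the
      // max-steps snapshot stops growing; that set captures the oscillation.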
      while (static_cast<int>(changed.size()) > previous_change_size) {
        if (!RunOneIteration(is_locked, is_updated, producer_request,
                             consumer_requests, producers, consumers,
                             merged_layouts, module, module_hash, &steps)
                 .ok()) {
          module.emitOpError() << "Failure running iteration.";
          return signalPassFailure();
        }
        if (!CompareMergedLayouts(merged_layouts_at_max_steps, merged_layouts,
                                  changed)
                 .ok()) {
          module.emitOpError() << "Failure comparing merged layouts.";
          return signalPassFailure();
        }
        previous_change_size = changed.size();
      }

      // Step 3. Values whose layouts haven't changed are not part of the
      // cyclic problem, so freeze them.
      for (const auto& value_and_layout : merged_layouts) {
        const mlir::Value value = value_and_layout.getFirst();
        if (changed.find(value) == changed.end()) {
          is_locked.insert(value);
        }
      }
      // Step 4. Any information corresponding to the changed layouts is
      // stale, so clear all producer requests, consumer requests, and merged
      // layouts for those values.
      for (const mlir::Value changed_value : changed) {
        producer_request.erase(changed_value);
        consumer_requests.erase(changed_value);
        merged_layouts.erase(changed_value);
      }

      // Step 5. Run ComputeLayout again on all the ops linked to the changed
      // layouts. The next stage will take this information and merge again.
      llvm::DenseSet<mlir::Operation*> operations_needing_update;
      is_updated = changed;
      GetOperationsNeedingUpdate(is_updated, consumers,
                                 operations_needing_update);
      is_updated.clear();

      for (auto* op : operations_needing_update) {
        if (mlir::failed(UpdateLayoutsForOp(op, producers, merged_layouts,
                                            producer_request,
                                            consumer_requests, is_updated))) {
          module.emitOpError() << "Failure in UpdateLayoutsForOp.";
          return signalPassFailure();
        }
      }
    }

    if (stage >= kLayoutPropagationMaxStages) {
      FindRootsAndEmitError(module, producers, is_updated);
      return signalPassFailure();
    }

    if (mlir::failed(
            CopyLayoutsForSkippedOps(module, tf_dialect, merged_layouts)))
      return signalPassFailure();

    if (VLOG_IS_ON(2)) {
      LogLayoutsAndOps(stage, module_hash, merged_layouts, module);
    }

    if (!AllOpResultsHaveLayouts(&module, tf_dialect, merged_layouts))
      return signalPassFailure();

    if (mlir::failed(InsertDTensorLayoutOps(builder, merged_layouts)))
      return signalPassFailure();

    // Handle layout of control flow operations.
    llvm::SmallVector<mlir::TF::IfRegionOp, 4> if_ops;
    llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops;
    module.walk([&](mlir::Operation* op) {
      if (auto if_op = llvm::dyn_cast<mlir::TF::IfRegionOp>(op))
        if_ops.emplace_back(if_op);
      else if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op))
        while_ops.emplace_back(while_op);
    });

    if (mlir::failed(InsertRelayoutForWhileLoops(while_ops, builder)))
      return signalPassFailure();

    if (mlir::failed(
            InsertDTensorLayoutForIfRegionOp(if_ops, builder.getContext())))
      return signalPassFailure();
  }
};

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorLayoutPropagationPassV2() {
  return std::make_unique<DLayoutPropagationPassV2>();
}

}  // namespace dtensor
}  // namespace tensorflow