1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass takes TensorFlow executor dialect IslandOps and
// merges the ones that contain operations marked to run on TPU.
18
19 #include <algorithm>
20 #include <iterator>
21 #include <queue>
22 #include <tuple>
23
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/None.h"
26 #include "llvm/ADT/Optional.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/iterator_range.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Debug.h"
34 #include "mlir/IR/Attributes.h" // from @llvm-project
35 #include "mlir/IR/Block.h" // from @llvm-project
36 #include "mlir/IR/Builders.h" // from @llvm-project
37 #include "mlir/IR/Location.h" // from @llvm-project
38 #include "mlir/IR/Operation.h" // from @llvm-project
39 #include "mlir/IR/SymbolTable.h" // from @llvm-project
40 #include "mlir/IR/UseDefLists.h" // from @llvm-project
41 #include "mlir/IR/Visitors.h" // from @llvm-project
42 #include "mlir/Pass/Pass.h" // from @llvm-project
43 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
44 #include "mlir/Support/LLVM.h" // from @llvm-project
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
47 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
49 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
50 #include "tensorflow/core/platform/logging.h"
51
52 #define DEBUG_TYPE "tf-executor-tpu-v1-island-coarsening"
53
54 namespace mlir {
55 namespace tf_executor {
56
57 namespace {
58
// Attribute whose string value is used directly as the TPU cluster name (e.g.
// on `TPUCompilationResult` ops); see GetTpuClusterName.
constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
// Cluster name assigned to TPU ops that carry no replication info attribute.
constexpr llvm::StringRef kNoReplicationCluster = "__no_replication_cluster";
61
// This pass is a variant of the island coarsening that is limited to
// TPU-annotated operations and intended to preserve backward compatibility with
// TFv1.
//
// For every `tf_executor.graph` in a function that is not reachable from a
// `TPUPartitionedCall`, the single-op islands belonging to the same TPU cluster
// (together with qualifying special TPU input/output islands) are merged into
// one island. Functions reachable from a `TPUPartitionedCall` are not
// coarsened; they are tagged with `kSkipIslandOutlining` instead.
struct TpuV1BridgeExecutorIslandCoarsening
    : public TF::TpuV1BridgeExecutorIslandCoarseningPassBase<
          TpuV1BridgeExecutorIslandCoarsening> {
  void runOnOperation() override;
};
70
71 // Returns name of TPU cluster, if op belongs to a TPU cluster. Otherwise,
72 // returns `llvm::None`.
GetTpuClusterName(Operation * op)73 llvm::Optional<llvm::StringRef> GetTpuClusterName(Operation* op) {
74 if (auto tpu_status = op->getAttrOfType<StringAttr>(kTpuStatusAttr)) {
75 // Borrow cluster name from TPU status (for `TPUCompilationResult` op).
76 return tpu_status.getValue();
77 }
78 auto device_type = op->getAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr);
79 if (!device_type || device_type.getValue() != TF::kTpuDevice) {
80 // Op does not belong to a TPU cluster.
81 return llvm::None;
82 }
83 // Op belongs to a TPU cluster.
84 if (auto replication_info =
85 op->getAttrOfType<StringAttr>(TF::kReplicationInfoAttr)) {
86 // Borrow cluster name from replication info.
87 return replication_info.getValue();
88 }
89 // Use special cluster name for non-replicated case.
90 return kNoReplicationCluster;
91 }
92
HasDataDependencyWithUnscheduledOp(Operation & op,Block * block,SmallPtrSet<Operation *,16> & unscheduled_ops)93 bool HasDataDependencyWithUnscheduledOp(
94 Operation& op, Block* block, SmallPtrSet<Operation*, 16>& unscheduled_ops) {
95 WalkResult ready_to_schedule = op.walk([&](Operation* nested_op) {
96 for (Value operand : nested_op->getOperands()) {
97 Operation* defining_op = operand.getDefiningOp();
98 if (!defining_op) continue;
99 Operation* producer_in_block = block->findAncestorOpInBlock(*defining_op);
100 if (producer_in_block && producer_in_block != &op &&
101 unscheduled_ops.count(producer_in_block)) {
102 // Found an operand that isn't scheduled yet, interrupt the walk.
103 return WalkResult::interrupt();
104 }
105 }
106 return WalkResult::advance();
107 });
108 return ready_to_schedule.wasInterrupted();
109 }
110
HasControlDependencyWithUnscheduledOp(Operation & op,Block * block,SmallPtrSet<Operation *,16> & unscheduled_ops)111 bool HasControlDependencyWithUnscheduledOp(
112 Operation& op, Block* block, SmallPtrSet<Operation*, 16>& unscheduled_ops) {
113 IslandOp island_op = dyn_cast<IslandOp>(op);
114 if (!island_op) {
115 return false;
116 }
117 for (Value input : island_op.controlInputs()) {
118 Operation* defining_op = input.getDefiningOp();
119 if (!defining_op) continue;
120 Operation* producer_in_block = block->findAncestorOpInBlock(*defining_op);
121 if (producer_in_block && producer_in_block != &op &&
122 unscheduled_ops.count(producer_in_block)) {
123 // Found an operand that isn't scheduled yet, return true.
124 return true;
125 }
126 }
127 return false;
128 }
129
130 // Sorts the operations in the provided range to enforce dominance.
131 // This is useful after fusing / reorganizing Operations in a block and later
132 // needing to readjust the ordering to ensure dominance.
SortTopologically(Block::iterator begin,Block::iterator end)133 LogicalResult SortTopologically(Block::iterator begin, Block::iterator end) {
134 Block* block = begin->getBlock();
135 // Either sort from `begin` to end of block or both `begin` and
136 // `end` should belong to the same block.
137 assert(end == block->end() ||
138 end->getBlock() == block && "ops must be in the same block");
139
140 // Track the ops that still need to be scheduled in a set.
141 SmallPtrSet<Operation*, 16> unscheduled_ops;
142 for (Operation& op : llvm::make_range(begin, end))
143 unscheduled_ops.insert(&op);
144
145 Block::iterator last_scheduled_op = begin;
146 while (!unscheduled_ops.empty()) {
147 bool scheduled_at_least_once = false;
148 // Loop over the ops that are not sorted yet, try to find the ones "ready",
149 // i.e. the ones for which there aren't any operand produced by an op in the
150 // set, and "schedule" it (move it before the last_scheduled_op).
151 for (Operation& op : llvm::make_range(last_scheduled_op, end)) {
152 if (HasDataDependencyWithUnscheduledOp(op, block, unscheduled_ops) ||
153 HasControlDependencyWithUnscheduledOp(op, block, unscheduled_ops)) {
154 continue;
155 }
156 unscheduled_ops.erase(&op);
157 if (Block::iterator(op) != last_scheduled_op)
158 op.moveBefore(block, last_scheduled_op);
159 else
160 ++last_scheduled_op;
161 scheduled_at_least_once = true;
162 }
163 if (!scheduled_at_least_once) return failure();
164 }
165 return success();
166 }
167
168 // Looks for an IslandOp that wraps a single operation tagged with the
169 // _replication_info attribute, and merges it with all the following operations
170 // in the block. Sets the `changed` boolean to true if any island is merged.
171 // Returns a failure if a cycle prevents the merge from happening correctly
172 // without breaking dominance. The IR is left in invalid state in case of
173 // failure.
CollectCandidateIslands(llvm::function_ref<bool (llvm::StringRef,Operation *)> is_op_calling_func_for_cluster,Operation * op,StringRef cluster_name,SmallPtrSet<Operation *,16> & islands_set,SmallPtrSet<Operation *,16> & wrapped_ops)174 void CollectCandidateIslands(
175 llvm::function_ref<bool(llvm::StringRef, Operation*)>
176 is_op_calling_func_for_cluster,
177 Operation* op, StringRef cluster_name,
178 SmallPtrSet<Operation*, 16>& islands_set,
179 SmallPtrSet<Operation*, 16>& wrapped_ops) {
180 for (Operation& candidate_op : llvm::make_early_inc_range(
181 llvm::make_range(op->getIterator(), op->getBlock()->end()))) {
182 IslandOp candidate_island = dyn_cast<IslandOp>(candidate_op);
183 if (!candidate_island || !candidate_island.WrapsSingleOp()) continue;
184 // Check if we have an operation with the expected attribute.
185 Operation& candidate_wrapped_op = candidate_island.GetBody().front();
186
187 // The op might be a special TPU input/output op and may have been already
188 // added to the list of islands to be merged.
189 if (wrapped_ops.contains(&candidate_wrapped_op)) {
190 continue;
191 }
192
193 llvm::Optional<llvm::StringRef> result =
194 GetTpuClusterName(&candidate_wrapped_op);
195 llvm::StringRef candidate_cluster_name;
196 if (result.has_value()) {
197 candidate_cluster_name = result.getValue();
198 } else if (is_op_calling_func_for_cluster(cluster_name,
199 &candidate_wrapped_op)) {
200 candidate_cluster_name = cluster_name;
201 }
202 if (candidate_cluster_name != cluster_name) continue;
203
204 // Add the current op to the set of ops which are planned to be merged into
205 // one cluster.
206 islands_set.insert(candidate_island);
207 wrapped_ops.insert(&candidate_wrapped_op);
208 }
209 }
210
CreateMergedIsland(IslandOp island,SmallVector<IslandOp,16> & islands,SmallPtrSet<Operation *,16> & wrapped_ops)211 IslandOp CreateMergedIsland(IslandOp island, SmallVector<IslandOp, 16>& islands,
212 SmallPtrSet<Operation*, 16>& wrapped_ops) {
213 // Compute the result of the merged island, these are the values produced by
214 // the islands that are merged if they have a use in an island not merged,
215 // i.e. a value that escapes.
216 llvm::SmallVector<Type, 4> result_types;
217 for (IslandOp new_op : islands) {
218 for (Value result : new_op.outputs()) {
219 if (llvm::any_of(result.getUsers(), [&](OpOperand user) {
220 return !wrapped_ops.count(user.getOwner());
221 }))
222 result_types.push_back(result.getType());
223 }
224 }
225
226 IslandOp new_island = OpBuilder(island).create<IslandOp>(
227 island.getLoc(), result_types,
228 /*control=*/ControlType::get(island.getContext()),
229 /*controlInputs=*/island.getOperands());
230 new_island.body().push_back(new Block);
231
232 // Move the operations in the new island, gather the results of the new yield.
233 Block& island_body = new_island.GetBody();
234 SmallVector<Value, 16> yield_operands;
235 for (IslandOp island : islands) {
236 Operation& wrapped_op = island.GetBody().front();
237 wrapped_op.moveBefore(&island_body, island_body.end());
238
239 // For every result of the wrapped_op, it needs to get passed to the yield
240 // operation, only if it escapes the island.
241 for (auto result : llvm::zip(island.outputs(), wrapped_op.getResults())) {
242 if (llvm::any_of(std::get<0>(result).getUsers(), [&](OpOperand user) {
243 return !wrapped_ops.count(user.getOwner());
244 }))
245 yield_operands.push_back(std::get<1>(result));
246 }
247 }
248 OpBuilder::atBlockEnd(&island_body)
249 .create<YieldOp>(new_island.getLoc(), yield_operands);
250
251 // remap results of the new islands to the user outside of the island.
252 int current_result = 0;
253 Value control = new_island.control();
254 for (IslandOp island : islands) {
255 YieldOp yield_op = island.GetYield();
256 for (const auto& idx_result : llvm::enumerate(island.outputs())) {
257 Value result = idx_result.value();
258
259 bool has_external_use = false;
260 for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
261 if (wrapped_ops.count(use.getOwner()))
262 use.set(yield_op.getOperand(idx_result.index()));
263 else
264 has_external_use = true;
265 }
266 if (has_external_use) {
267 result.replaceAllUsesWith(new_island.getResult(current_result));
268 ++current_result;
269 }
270 }
271 island.control().replaceAllUsesWith(control);
272 island.erase();
273 }
274 return new_island;
275 }
276
277 // Looks for an IslandOp that wraps a single operation tagged with the
278 // _replication_info attribute, and merges it with all the following operations
279 // in the block. Sets the `changed` boolean to true if any island is merged.
280 // Returns a failure if a cycle prevents the merge from happening correctly
281 // without breaking dominance. The IR is left in invalid state in case of
282 // failure.
MergeIsland(llvm::function_ref<bool (StringRef,Operation *)> is_op_calling_func_for_cluster,llvm::SmallDenseMap<StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_ops_map,Operation * op,bool * changed)283 LogicalResult MergeIsland(
284 llvm::function_ref<bool(StringRef, Operation*)>
285 is_op_calling_func_for_cluster,
286 llvm::SmallDenseMap<StringRef, llvm::SmallDenseSet<Operation*>>&
287 cluster_to_tpu_ops_map,
288 Operation* op, bool* changed) {
289 // Find the first island wrapping a single operation with the
290 // `_replication_info` attribute, it'll be used as the root of the algorithm
291 // to find the other operations that are part of the same cluster.
292 IslandOp island = dyn_cast<IslandOp>(*op);
293 if (!island || !island.WrapsSingleOp()) return success();
294 Operation& wrapped_op = island.GetBody().front();
295
296 llvm::Optional<llvm::StringRef> result = GetTpuClusterName(&wrapped_op);
297 if (!result.has_value()) return success();
298 llvm::StringRef cluster_name = result.getValue();
299
300 // We found a _replication_info, let's build an island for the full cluster!
301 LLVM_DEBUG(llvm::dbgs() << "Processing candidate island: "
302 << *island.getOperation() << "\n");
303
304 // Collect the islands to merge together in this new cluster starting with the
305 // given island.
306 SmallVector<IslandOp, 16> islands;
307 SmallPtrSet<Operation*, 16> islands_set;
308 SmallPtrSet<Operation*, 16> wrapped_ops;
309
310 CollectCandidateIslands(is_op_calling_func_for_cluster, op, cluster_name,
311 islands_set, wrapped_ops);
312
313 if (cluster_to_tpu_ops_map.count(cluster_name)) {
314 for (auto tpu_op : cluster_to_tpu_ops_map[cluster_name]) {
315 islands_set.insert(tpu_op);
316 wrapped_ops.insert(&dyn_cast<IslandOp>(*tpu_op).GetBody().front());
317 }
318 }
319
320 // Get the sequential order of the candidate islands in the block.
321 // Later we can merge the candidate islands in this order.
322 // The dominance is guaranteed in this order.
323 for (Operation& candidate_op : llvm::make_early_inc_range(
324 llvm::make_range(op->getBlock()->begin(), op->getBlock()->end()))) {
325 IslandOp candidate_island = dyn_cast<IslandOp>(candidate_op);
326 if (!candidate_island || !candidate_island.WrapsSingleOp()) continue;
327 if (islands_set.contains(candidate_island)) {
328 islands.push_back(candidate_island);
329 }
330 }
331
332 // If no other island was found to merge with the existing one, just move on.
333 if (islands.size() <= 1) return success();
334
335 *changed = true;
336 Operation* first_op_after = islands.back()->getNextNode();
337
338 // We create the merged island at the location of the first island that was
339 // merged (excluding special TPU input/output ops).
340 IslandOp new_island = CreateMergedIsland(island, islands, wrapped_ops);
341
342 // Ensure dominance by sorting the range of islands that were merged.
343 return SortTopologically(Block::iterator(new_island.getOperation()),
344 Block::iterator(first_op_after));
345 }
346
347 // Returns all functions that can be reached from TPUPartitionedCall ops.
FindTPUPartitionedCallReachableFunctions(ModuleOp module)348 SmallPtrSet<Operation*, 16> FindTPUPartitionedCallReachableFunctions(
349 ModuleOp module) {
350 SymbolTableCollection table;
351 SymbolUserMap symbol_map(table, module);
352 llvm::DenseMap<func::FuncOp, llvm::DenseSet<func::FuncOp>> caller_callee_map;
353 // Creates work queue for determining reachability below.
354 std::queue<func::FuncOp> function_worklist;
355
356 for (auto func : module.getOps<func::FuncOp>()) {
357 for (auto user : symbol_map.getUsers(func)) {
358 // Populates work queue with func ops called from TPUPartionedCall.
359 if (llvm::isa<TF::TPUPartitionedCallOp>(user)) {
360 function_worklist.push(func);
361 }
362 // Populates caller to called func map.
363 if (func::FuncOp caller = user->getParentOfType<func::FuncOp>()) {
364 caller_callee_map[caller].insert(func);
365 }
366 }
367 }
368
369 // Determines reached ops starting from TPUPartionedCall ops
370 // and iteratively descending through called ops.
371 SmallPtrSet<Operation*, 16> reachable_functions;
372 while (!function_worklist.empty()) {
373 func::FuncOp caller = function_worklist.front();
374 function_worklist.pop();
375 if (reachable_functions.insert(caller).second) {
376 for (auto callee : caller_callee_map[caller]) {
377 function_worklist.push(callee);
378 }
379 }
380 }
381 return reachable_functions;
382 }
383
384 // valid means all the ops in the vector are belong to the same cluster.
is_valid_special_tpu_op(std::vector<IslandOp> & ops,llvm::StringRef cluster_name,llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_op_map)385 bool is_valid_special_tpu_op(
386 std::vector<IslandOp>& ops, llvm::StringRef cluster_name,
387 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
388 cluster_to_tpu_op_map) {
389 for (IslandOp op : ops) {
390 Operation* wrapped_op = &op.GetBody().front();
391 llvm::Optional<llvm::StringRef> wrapped_op_cluster_name =
392 GetTpuClusterName(wrapped_op);
393
394 bool op_has_inconsistent_cluster_name =
395 wrapped_op_cluster_name.has_value() &&
396 !wrapped_op_cluster_name.getValue().equals(cluster_name);
397
398 if (op_has_inconsistent_cluster_name) {
399 return false;
400 }
401 }
402 return true;
403 }
404
collect_input_defining_islands(IslandOp op,std::vector<IslandOp> & ops)405 void collect_input_defining_islands(IslandOp op, std::vector<IslandOp>& ops) {
406 Operation* wrapped_op = &op.GetBody().front();
407 for (Value operand : wrapped_op->getOperands()) {
408 IslandOp wrapper = dyn_cast_or_null<IslandOp>(operand.getDefiningOp());
409 if (!wrapper || !wrapper.WrapsSingleOp()) continue;
410 ops.push_back(wrapper);
411 }
412 }
413
collect_output_users_islands(IslandOp op,std::vector<IslandOp> & ops)414 void collect_output_users_islands(IslandOp op, std::vector<IslandOp>& ops) {
415 for (Value result : op->getResults()) {
416 for (OpOperand use : result.getUsers()) {
417 IslandOp wrapper =
418 dyn_cast_or_null<IslandOp>(use.getOwner()->getParentOp());
419 if (!wrapper || !wrapper.WrapsSingleOp()) continue;
420 ops.push_back(wrapper);
421 }
422 }
423 }
424
// Starting from `candidate_island` (a single-op island in TPU cluster
// `cluster_name`), walks breadth-first over neighboring single-op islands that
// wrap special TPU ops, and records them in
// `cluster_to_tpu_op_map[cluster_name]`. The direction depends on `incoming`:
//   - `incoming == true`: walks towards producers, accepting
//     `tf.TPUReplicatedInput`, `tf.TPUPartitionedInput` and `tf.Identity`.
//   - `incoming == false`: walks towards consumers, accepting
//     `tf.TPUReplicatedOutput`, `tf.TPUPartitionedOutput` and `tf.Identity`.
// An accepted island is neither recorded nor expanded further if its wrapped
// op carries the replication info attribute or is already in
// `visited_wrapped_ops`. Newly recorded wrapped ops are added to
// `visited_wrapped_ops`.
// Returns false as soon as an accepted island is connected, on its far side,
// to an op of a different TPU cluster; returns true otherwise.
bool AddSpecialTpuOps(
    IslandOp candidate_island, llvm::StringRef cluster_name,
    llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
        cluster_to_tpu_op_map,
    SmallPtrSetImpl<Operation*>& visited_wrapped_ops, bool incoming) {
  // FIFO worklist of islands whose neighbors still have to be inspected.
  std::queue<IslandOp> op_worklist;
  // Scratch buffer holding the neighbors of the island being processed.
  std::vector<IslandOp> ops;

  op_worklist.push(candidate_island);

  while (!op_worklist.empty()) {
    IslandOp current_op = op_worklist.front();
    op_worklist.pop();
    ops.clear();
    // Neighbors on the side being explored: producers when `incoming`,
    // consumers otherwise.
    if (incoming) {
      collect_input_defining_islands(current_op, ops);
    } else {
      collect_output_users_islands(current_op, ops);
    }
    for (IslandOp wrapper : ops) {
      Operation* wrapped_op = &wrapper.GetBody().front();
      // Islands on the far side of `wrapper`, used for the consistency check.
      std::vector<IslandOp> child_ops;
      if (incoming) {
        // Looks at captured operands of the op wrapped by `current_op` to bring
        // special TPU ops such as tf.TPUReplicatedInput and
        // tf.TPUPartitionedInput into the island as well. These ops are brought
        // in only if they do not already have a cluster assigned to them (via
        // the `_replication_info` attribute value).
        // `tf.Identity` op is also treated as a special TPU op since it can act
        // as the connection to a `tf.TPUReplicatedInput` or
        // `tf.TPUPartitionedInput`. For example, consider the following
        // pseudocode:
        //   %0 = tf_executor.island wraps "tf.OpA" (){_replication_info = 'c'}
        //   %1 = tf_executor.island wraps "tf.Identity"(%0)
        //   %2 = tf_executor.island wraps "tf.TPUReplicatedInput"(%1)

        if (!isa<TF::TPUReplicatedInputOp, TF::TPUPartitionedInputOp,
                 TF::IdentityOp>(wrapped_op))
          continue;
        collect_output_users_islands(wrapper, child_ops);
      } else {
        // Looks at the results of `current_op` to bring special TPU ops such as
        // tf.TPUReplicatedOutput and tf.TPUPartitionedOutput into the island as
        // well. These ops are brought in only if they do not already have a
        // cluster assigned to them (via the replication info attribute).
        // `tf.Identity` op is also treated as a special TPU op since it can act
        // as the connection to a `tf.TPUReplicatedOutput` or
        // `tf.TPUPartitionedOutput`.
        if (!isa<TF::TPUReplicatedOutputOp, TF::TPUPartitionedOutputOp,
                 TF::IdentityOp>(wrapped_op))
          continue;
        collect_input_defining_islands(wrapper, child_ops);
      }
      // Fail if the special op also connects to an op of another TPU cluster.
      // This check runs even for ops that are skipped below.
      if (!is_valid_special_tpu_op(child_ops, cluster_name,
                                   cluster_to_tpu_op_map)) {
        return false;
      }

      // Only inputs/outputs that do not have a cluster name assigned are
      // considered for special handling. Otherwise, island coarsening logic
      // should be able to handle it.
      if (wrapped_op->hasAttrOfType<StringAttr>(TF::kReplicationInfoAttr))
        continue;
      if (visited_wrapped_ops.contains(wrapped_op)) continue;
      op_worklist.push(wrapper);
      cluster_to_tpu_op_map[cluster_name].insert(wrapper);
      visited_wrapped_ops.insert(wrapped_op);
    }
  }
  return true;
}
495
CollectSpecialTpuOps(llvm::function_ref<bool (llvm::StringRef,Operation *)> is_op_calling_func_for_cluster,Operation * op,llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_op_map,SmallPtrSet<Operation *,16> & visited_wrapped_ops)496 LogicalResult CollectSpecialTpuOps(
497 llvm::function_ref<bool(llvm::StringRef, Operation*)>
498 is_op_calling_func_for_cluster,
499 Operation* op,
500 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
501 cluster_to_tpu_op_map,
502 SmallPtrSet<Operation*, 16>& visited_wrapped_ops) {
503 IslandOp island = dyn_cast<IslandOp>(*op);
504 if (!island || !island.WrapsSingleOp()) return success();
505 Operation& wrapped_op = island.GetBody().front();
506
507 if (visited_wrapped_ops.contains(&wrapped_op)) return success();
508
509 llvm::Optional<llvm::StringRef> result = GetTpuClusterName(&wrapped_op);
510 if (!result.has_value()) return success();
511 llvm::StringRef cluster_name = result.getValue();
512
513 visited_wrapped_ops.insert(&wrapped_op);
514
515 if (!AddSpecialTpuOps(island, cluster_name, cluster_to_tpu_op_map,
516 visited_wrapped_ops, /*incoming=*/true)) {
517 return failure();
518 }
519 if (!AddSpecialTpuOps(island, cluster_name, cluster_to_tpu_op_map,
520 visited_wrapped_ops, /*incoming=*/false)) {
521 return failure();
522 }
523 return success();
524 }
525
526 // Whenever we find an Identity op that is unqualified, we remove this Identity
527 // op from the list `tpu_ops`. An unqualified Identity op indicates either its
528 // inputs or its outputs do not belong to the same cluster.
ExcludeIdentityOp(llvm::SmallDenseSet<Operation * > & tpu_ops,llvm::StringRef & target_cluster_name,bool incoming)529 bool ExcludeIdentityOp(llvm::SmallDenseSet<Operation*>& tpu_ops,
530 llvm::StringRef& target_cluster_name, bool incoming) {
531 for (auto iter = tpu_ops.begin(); iter != tpu_ops.end(); iter++) {
532 auto island_op = llvm::dyn_cast<IslandOp>(*iter);
533 if (llvm::dyn_cast_or_null<TF::IdentityOp>(island_op.GetBody().front())) {
534 if (island_op.outputs().use_empty()) {
535 tpu_ops.erase(iter);
536 return true;
537 }
538 std::vector<IslandOp> ops;
539 if (incoming) {
540 collect_output_users_islands(island_op, ops);
541 } else {
542 collect_input_defining_islands(island_op, ops);
543 }
544 for (IslandOp wrapper : ops) {
545 Operation* wrapped_op = &wrapper.GetBody().front();
546 auto cluster_name = GetTpuClusterName(wrapped_op);
547 if (cluster_name.hasValue() &&
548 cluster_name.getValue() != target_cluster_name) {
549 tpu_ops.erase(iter);
550 return true;
551 }
552 if (!cluster_name.hasValue() &&
553 !tpu_ops.count(wrapper.getOperation())) {
554 tpu_ops.erase(iter);
555 return true;
556 }
557 }
558 }
559 }
560 return false;
561 }
562
ExcludeUnqualifiedIdentityOp(llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_ops_map,bool incoming)563 void ExcludeUnqualifiedIdentityOp(
564 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
565 cluster_to_tpu_ops_map,
566 bool incoming) {
567 for (auto& [target_cluster_name, tpu_ops] : cluster_to_tpu_ops_map) {
568 bool changed = true;
569 while (changed) {
570 changed = ExcludeIdentityOp(tpu_ops, target_cluster_name, incoming);
571 }
572 }
573 }
574
ExcludeUnqualifiedIdentityOp(llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_ops_map)575 void ExcludeUnqualifiedIdentityOp(
576 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
577 cluster_to_tpu_ops_map) {
578 ExcludeUnqualifiedIdentityOp(cluster_to_tpu_ops_map, /*incoming=*/true);
579 ExcludeUnqualifiedIdentityOp(cluster_to_tpu_ops_map, /*incoming=*/false);
580 }
581
582 // Erase Identity op which does not contain `_replication_info` in the merged
583 // island.
EraseIdentityWithNoReplicationInfo(Block & graph_body)584 void EraseIdentityWithNoReplicationInfo(Block& graph_body) {
585 for (Operation& island_op : graph_body) {
586 IslandOp island = dyn_cast<IslandOp>(island_op);
587 if (!island || island.WrapsSingleOp()) continue;
588 for (Operation& op : llvm::make_early_inc_range(island.GetBody())) {
589 llvm::Optional<llvm::StringRef> cluster_name = GetTpuClusterName(&op);
590 if (cluster_name.hasValue()) continue;
591 if (auto identity_op = llvm::dyn_cast_or_null<TF::IdentityOp>(op)) {
592 auto identity_input = identity_op.input();
593 auto output = identity_op.output();
594 output.replaceAllUsesWith(identity_input);
595 identity_op.erase();
596 }
597 }
598 }
599 }
600
runOnOperation()601 void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
602 SymbolTable symbol_table(getOperation());
603
604 // Map tpu cluster names to the functions that contain operations for this
605 // cluster.
606 DenseMap<StringRef, DenseSet<func::FuncOp>> tpu_funcs;
607 for (func::FuncOp func_op : getOperation().getOps<func::FuncOp>()) {
608 func_op.walk([&](Operation* op) {
609 llvm::Optional<llvm::StringRef> cluster_name_opt = GetTpuClusterName(op);
610 if (cluster_name_opt.has_value()) {
611 tpu_funcs[cluster_name_opt.getValue()].insert(func_op);
612 }
613 });
614 }
615
616 // Return true if the operation is containing a reference to a function
617 // containing operations for this cluster.
618 auto is_op_calling_func_for_cluster = [&](llvm::StringRef cluster,
619 Operation* op) {
620 auto funcs_for_cluster = tpu_funcs.find(cluster);
621 assert(funcs_for_cluster != tpu_funcs.end());
622 assert(!funcs_for_cluster->second.empty());
623 if (funcs_for_cluster->second.size() == 1) return false;
624 for (NamedAttribute attr : op->getAttrs()) {
625 auto symbol_ref = attr.getValue().dyn_cast<FlatSymbolRefAttr>();
626 if (!symbol_ref) continue;
627 func::FuncOp callee =
628 symbol_table.lookup<func::FuncOp>(symbol_ref.getValue());
629 if (!callee) continue;
630 if (funcs_for_cluster->second.count(callee)) return true;
631 }
632 return false;
633 };
634
635 // Populates skip set with functions reachable from TPUPartionedCall ops.
636 const auto functions_to_skip =
637 FindTPUPartitionedCallReachableFunctions(getOperation());
638 for (func::FuncOp func_op : getOperation().getOps<func::FuncOp>()) {
639 if (functions_to_skip.contains(func_op)) {
640 OpBuilder builder(func_op);
641 // Mark this function as being skipped in island outlining.
642 func_op->setAttr(mlir::TF::kSkipIslandOutlining,
643 builder.getBoolAttr(true));
644 continue;
645 }
646
647 func_op.walk([&](GraphOp graph) {
648 Block& graph_body = graph.GetBody();
649 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>
650 cluster_to_tpu_ops_map;
651 SmallPtrSet<Operation*, 16> visited_ops;
652 for (Operation& op : graph_body) {
653 if (failed(CollectSpecialTpuOps(is_op_calling_func_for_cluster, &op,
654 cluster_to_tpu_ops_map, visited_ops))) {
655 graph.emitError()
656 << "Collect special Tpu ops failed: "
657 << "Graph contains op with inconsistent cluster info\n";
658 signalPassFailure();
659 return WalkResult::interrupt();
660 }
661 }
662
663 ExcludeUnqualifiedIdentityOp(cluster_to_tpu_ops_map);
664
665 // Iterate until fixed point on the block, as it may contain multiple
666 // clusters.
667 bool changed = true;
668 while (changed) {
669 changed = false;
670 for (Operation& op : graph_body) {
671 if (failed(MergeIsland(is_op_calling_func_for_cluster,
672 cluster_to_tpu_ops_map, &op, &changed))) {
673 graph.emitError()
674 << "Merging island failed: the TPU cluster likely "
675 << "contains a cycle with non-TPU operations or has "
676 "unsupported ops\n";
677 signalPassFailure();
678 return WalkResult::interrupt();
679 }
680 // If islands were merged, restart scanning the block from the
681 // beginning as we lost track of where to continue.
682 if (changed) break;
683 }
684 }
685
686 // Need to remove the redundant `Identity` ops in the same cluster.
687 // Redundant `Identity` op indicates that no `_replicatation_info`
688 // attribute is attached.
689 EraseIdentityWithNoReplicationInfo(graph_body);
690
691 return WalkResult::advance();
692 });
693 }
694 }
695
696 } // namespace
697
698 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass()699 CreateTFExecutorTPUV1IslandCoarseningPass() {
700 return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
701 }
702
703 } // namespace tf_executor
704 } // namespace mlir
705