1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This transformation pass takes TensorFlow executor dialect IslandOps and
17 // merges the one that contains operation marked to run on TPU.
18 
19 #include <algorithm>
20 #include <iterator>
21 #include <queue>
22 #include <tuple>
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/None.h"
26 #include "llvm/ADT/Optional.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/iterator_range.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Debug.h"
34 #include "mlir/IR/Attributes.h"  // from @llvm-project
35 #include "mlir/IR/Block.h"  // from @llvm-project
36 #include "mlir/IR/Builders.h"  // from @llvm-project
37 #include "mlir/IR/Location.h"  // from @llvm-project
38 #include "mlir/IR/Operation.h"  // from @llvm-project
39 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
40 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
41 #include "mlir/IR/Visitors.h"  // from @llvm-project
42 #include "mlir/Pass/Pass.h"  // from @llvm-project
43 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
44 #include "mlir/Support/LLVM.h"  // from @llvm-project
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
47 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
48 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
49 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
50 #include "tensorflow/core/platform/logging.h"
51 
52 #define DEBUG_TYPE "tf-executor-tpu-v1-island-coarsening"
53 
54 namespace mlir {
55 namespace tf_executor {
56 
57 namespace {
58 
59 constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
60 constexpr llvm::StringRef kNoReplicationCluster = "__no_replication_cluster";
61 
62 // This pass is a variant of the island coarsening that is limited to
63 // TPU-annotated operations and intended to preserve backward compatibility with
64 // TFv1.
65 struct TpuV1BridgeExecutorIslandCoarsening
66     : public TF::TpuV1BridgeExecutorIslandCoarseningPassBase<
67           TpuV1BridgeExecutorIslandCoarsening> {
68   void runOnOperation() override;
69 };
70 
71 // Returns name of TPU cluster, if op belongs to a TPU cluster. Otherwise,
72 // returns `llvm::None`.
GetTpuClusterName(Operation * op)73 llvm::Optional<llvm::StringRef> GetTpuClusterName(Operation* op) {
74   if (auto tpu_status = op->getAttrOfType<StringAttr>(kTpuStatusAttr)) {
75     // Borrow cluster name from TPU status (for `TPUCompilationResult` op).
76     return tpu_status.getValue();
77   }
78   auto device_type = op->getAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr);
79   if (!device_type || device_type.getValue() != TF::kTpuDevice) {
80     // Op does not belong to a TPU cluster.
81     return llvm::None;
82   }
83   // Op belongs to a TPU cluster.
84   if (auto replication_info =
85           op->getAttrOfType<StringAttr>(TF::kReplicationInfoAttr)) {
86     // Borrow cluster name from replication info.
87     return replication_info.getValue();
88   }
89   // Use special cluster name for non-replicated case.
90   return kNoReplicationCluster;
91 }
92 
HasDataDependencyWithUnscheduledOp(Operation & op,Block * block,SmallPtrSet<Operation *,16> & unscheduled_ops)93 bool HasDataDependencyWithUnscheduledOp(
94     Operation& op, Block* block, SmallPtrSet<Operation*, 16>& unscheduled_ops) {
95   WalkResult ready_to_schedule = op.walk([&](Operation* nested_op) {
96     for (Value operand : nested_op->getOperands()) {
97       Operation* defining_op = operand.getDefiningOp();
98       if (!defining_op) continue;
99       Operation* producer_in_block = block->findAncestorOpInBlock(*defining_op);
100       if (producer_in_block && producer_in_block != &op &&
101           unscheduled_ops.count(producer_in_block)) {
102         // Found an operand that isn't scheduled yet, interrupt the walk.
103         return WalkResult::interrupt();
104       }
105     }
106     return WalkResult::advance();
107   });
108   return ready_to_schedule.wasInterrupted();
109 }
110 
HasControlDependencyWithUnscheduledOp(Operation & op,Block * block,SmallPtrSet<Operation *,16> & unscheduled_ops)111 bool HasControlDependencyWithUnscheduledOp(
112     Operation& op, Block* block, SmallPtrSet<Operation*, 16>& unscheduled_ops) {
113   IslandOp island_op = dyn_cast<IslandOp>(op);
114   if (!island_op) {
115     return false;
116   }
117   for (Value input : island_op.controlInputs()) {
118     Operation* defining_op = input.getDefiningOp();
119     if (!defining_op) continue;
120     Operation* producer_in_block = block->findAncestorOpInBlock(*defining_op);
121     if (producer_in_block && producer_in_block != &op &&
122         unscheduled_ops.count(producer_in_block)) {
123       // Found an operand that isn't scheduled yet, return true.
124       return true;
125     }
126   }
127   return false;
128 }
129 
130 // Sorts the operations in the provided range to enforce dominance.
131 // This is useful after fusing / reorganizing Operations in a block and later
132 // needing to readjust the ordering to ensure dominance.
SortTopologically(Block::iterator begin,Block::iterator end)133 LogicalResult SortTopologically(Block::iterator begin, Block::iterator end) {
134   Block* block = begin->getBlock();
135   // Either sort from `begin` to end of block or both `begin` and
136   // `end` should belong to the same block.
137   assert(end == block->end() ||
138          end->getBlock() == block && "ops must be in the same block");
139 
140   // Track the ops that still need to be scheduled in a set.
141   SmallPtrSet<Operation*, 16> unscheduled_ops;
142   for (Operation& op : llvm::make_range(begin, end))
143     unscheduled_ops.insert(&op);
144 
145   Block::iterator last_scheduled_op = begin;
146   while (!unscheduled_ops.empty()) {
147     bool scheduled_at_least_once = false;
148     // Loop over the ops that are not sorted yet, try to find the ones "ready",
149     // i.e. the ones for which there aren't any operand produced by an op in the
150     // set, and "schedule" it (move it before the last_scheduled_op).
151     for (Operation& op : llvm::make_range(last_scheduled_op, end)) {
152       if (HasDataDependencyWithUnscheduledOp(op, block, unscheduled_ops) ||
153           HasControlDependencyWithUnscheduledOp(op, block, unscheduled_ops)) {
154         continue;
155       }
156       unscheduled_ops.erase(&op);
157       if (Block::iterator(op) != last_scheduled_op)
158         op.moveBefore(block, last_scheduled_op);
159       else
160         ++last_scheduled_op;
161       scheduled_at_least_once = true;
162     }
163     if (!scheduled_at_least_once) return failure();
164   }
165   return success();
166 }
167 
168 // Looks for an IslandOp that wraps a single operation tagged with the
169 // _replication_info attribute, and merges it with all the following operations
170 // in the block. Sets the `changed` boolean to true if any island is merged.
171 // Returns a failure if a cycle prevents the merge from happening correctly
172 // without breaking dominance. The IR is left in invalid state in case of
173 // failure.
CollectCandidateIslands(llvm::function_ref<bool (llvm::StringRef,Operation *)> is_op_calling_func_for_cluster,Operation * op,StringRef cluster_name,SmallPtrSet<Operation *,16> & islands_set,SmallPtrSet<Operation *,16> & wrapped_ops)174 void CollectCandidateIslands(
175     llvm::function_ref<bool(llvm::StringRef, Operation*)>
176         is_op_calling_func_for_cluster,
177     Operation* op, StringRef cluster_name,
178     SmallPtrSet<Operation*, 16>& islands_set,
179     SmallPtrSet<Operation*, 16>& wrapped_ops) {
180   for (Operation& candidate_op : llvm::make_early_inc_range(
181            llvm::make_range(op->getIterator(), op->getBlock()->end()))) {
182     IslandOp candidate_island = dyn_cast<IslandOp>(candidate_op);
183     if (!candidate_island || !candidate_island.WrapsSingleOp()) continue;
184     // Check if we have an operation with the expected attribute.
185     Operation& candidate_wrapped_op = candidate_island.GetBody().front();
186 
187     // The op might be a special TPU input/output op and may have been already
188     // added to the list of islands to be merged.
189     if (wrapped_ops.contains(&candidate_wrapped_op)) {
190       continue;
191     }
192 
193     llvm::Optional<llvm::StringRef> result =
194         GetTpuClusterName(&candidate_wrapped_op);
195     llvm::StringRef candidate_cluster_name;
196     if (result.has_value()) {
197       candidate_cluster_name = result.getValue();
198     } else if (is_op_calling_func_for_cluster(cluster_name,
199                                               &candidate_wrapped_op)) {
200       candidate_cluster_name = cluster_name;
201     }
202     if (candidate_cluster_name != cluster_name) continue;
203 
204     // Add the current op to the set of ops which are planned to be merged into
205     // one cluster.
206     islands_set.insert(candidate_island);
207     wrapped_ops.insert(&candidate_wrapped_op);
208   }
209 }
210 
CreateMergedIsland(IslandOp island,SmallVector<IslandOp,16> & islands,SmallPtrSet<Operation *,16> & wrapped_ops)211 IslandOp CreateMergedIsland(IslandOp island, SmallVector<IslandOp, 16>& islands,
212                             SmallPtrSet<Operation*, 16>& wrapped_ops) {
213   // Compute the result of the merged island, these are the values produced by
214   // the islands that are merged if they have a use in an island not merged,
215   // i.e. a value that escapes.
216   llvm::SmallVector<Type, 4> result_types;
217   for (IslandOp new_op : islands) {
218     for (Value result : new_op.outputs()) {
219       if (llvm::any_of(result.getUsers(), [&](OpOperand user) {
220             return !wrapped_ops.count(user.getOwner());
221           }))
222         result_types.push_back(result.getType());
223     }
224   }
225 
226   IslandOp new_island = OpBuilder(island).create<IslandOp>(
227       island.getLoc(), result_types,
228       /*control=*/ControlType::get(island.getContext()),
229       /*controlInputs=*/island.getOperands());
230   new_island.body().push_back(new Block);
231 
232   // Move the operations in the new island, gather the results of the new yield.
233   Block& island_body = new_island.GetBody();
234   SmallVector<Value, 16> yield_operands;
235   for (IslandOp island : islands) {
236     Operation& wrapped_op = island.GetBody().front();
237     wrapped_op.moveBefore(&island_body, island_body.end());
238 
239     // For every result of the wrapped_op, it needs to get passed to the yield
240     // operation, only if it escapes the island.
241     for (auto result : llvm::zip(island.outputs(), wrapped_op.getResults())) {
242       if (llvm::any_of(std::get<0>(result).getUsers(), [&](OpOperand user) {
243             return !wrapped_ops.count(user.getOwner());
244           }))
245         yield_operands.push_back(std::get<1>(result));
246     }
247   }
248   OpBuilder::atBlockEnd(&island_body)
249       .create<YieldOp>(new_island.getLoc(), yield_operands);
250 
251   // remap results of the new islands to the user outside of the island.
252   int current_result = 0;
253   Value control = new_island.control();
254   for (IslandOp island : islands) {
255     YieldOp yield_op = island.GetYield();
256     for (const auto& idx_result : llvm::enumerate(island.outputs())) {
257       Value result = idx_result.value();
258 
259       bool has_external_use = false;
260       for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
261         if (wrapped_ops.count(use.getOwner()))
262           use.set(yield_op.getOperand(idx_result.index()));
263         else
264           has_external_use = true;
265       }
266       if (has_external_use) {
267         result.replaceAllUsesWith(new_island.getResult(current_result));
268         ++current_result;
269       }
270     }
271     island.control().replaceAllUsesWith(control);
272     island.erase();
273   }
274   return new_island;
275 }
276 
277 // Looks for an IslandOp that wraps a single operation tagged with the
278 // _replication_info attribute, and merges it with all the following operations
279 // in the block. Sets the `changed` boolean to true if any island is merged.
280 // Returns a failure if a cycle prevents the merge from happening correctly
281 // without breaking dominance. The IR is left in invalid state in case of
282 // failure.
MergeIsland(llvm::function_ref<bool (StringRef,Operation *)> is_op_calling_func_for_cluster,llvm::SmallDenseMap<StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_ops_map,Operation * op,bool * changed)283 LogicalResult MergeIsland(
284     llvm::function_ref<bool(StringRef, Operation*)>
285         is_op_calling_func_for_cluster,
286     llvm::SmallDenseMap<StringRef, llvm::SmallDenseSet<Operation*>>&
287         cluster_to_tpu_ops_map,
288     Operation* op, bool* changed) {
289   // Find the first island wrapping a single operation with the
290   // `_replication_info` attribute, it'll be used as the root of the algorithm
291   // to find the other operations that are part of the same cluster.
292   IslandOp island = dyn_cast<IslandOp>(*op);
293   if (!island || !island.WrapsSingleOp()) return success();
294   Operation& wrapped_op = island.GetBody().front();
295 
296   llvm::Optional<llvm::StringRef> result = GetTpuClusterName(&wrapped_op);
297   if (!result.has_value()) return success();
298   llvm::StringRef cluster_name = result.getValue();
299 
300   // We found a _replication_info, let's build an island for the full cluster!
301   LLVM_DEBUG(llvm::dbgs() << "Processing candidate island: "
302                           << *island.getOperation() << "\n");
303 
304   // Collect the islands to merge together in this new cluster starting with the
305   // given island.
306   SmallVector<IslandOp, 16> islands;
307   SmallPtrSet<Operation*, 16> islands_set;
308   SmallPtrSet<Operation*, 16> wrapped_ops;
309 
310   CollectCandidateIslands(is_op_calling_func_for_cluster, op, cluster_name,
311                           islands_set, wrapped_ops);
312 
313   if (cluster_to_tpu_ops_map.count(cluster_name)) {
314     for (auto tpu_op : cluster_to_tpu_ops_map[cluster_name]) {
315       islands_set.insert(tpu_op);
316       wrapped_ops.insert(&dyn_cast<IslandOp>(*tpu_op).GetBody().front());
317     }
318   }
319 
320   // Get the sequential order of the candidate islands in the block.
321   // Later we can merge the candidate islands in this order.
322   // The dominance is guaranteed in this order.
323   for (Operation& candidate_op : llvm::make_early_inc_range(
324            llvm::make_range(op->getBlock()->begin(), op->getBlock()->end()))) {
325     IslandOp candidate_island = dyn_cast<IslandOp>(candidate_op);
326     if (!candidate_island || !candidate_island.WrapsSingleOp()) continue;
327     if (islands_set.contains(candidate_island)) {
328       islands.push_back(candidate_island);
329     }
330   }
331 
332   // If no other island was found to merge with the existing one, just move on.
333   if (islands.size() <= 1) return success();
334 
335   *changed = true;
336   Operation* first_op_after = islands.back()->getNextNode();
337 
338   // We create the merged island at the location of the first island that was
339   // merged (excluding special TPU input/output ops).
340   IslandOp new_island = CreateMergedIsland(island, islands, wrapped_ops);
341 
342   // Ensure dominance by sorting the range of islands that were merged.
343   return SortTopologically(Block::iterator(new_island.getOperation()),
344                            Block::iterator(first_op_after));
345 }
346 
347 // Returns all functions that can be reached from TPUPartitionedCall ops.
FindTPUPartitionedCallReachableFunctions(ModuleOp module)348 SmallPtrSet<Operation*, 16> FindTPUPartitionedCallReachableFunctions(
349     ModuleOp module) {
350   SymbolTableCollection table;
351   SymbolUserMap symbol_map(table, module);
352   llvm::DenseMap<func::FuncOp, llvm::DenseSet<func::FuncOp>> caller_callee_map;
353   // Creates work queue for determining reachability below.
354   std::queue<func::FuncOp> function_worklist;
355 
356   for (auto func : module.getOps<func::FuncOp>()) {
357     for (auto user : symbol_map.getUsers(func)) {
358       // Populates work queue with func ops called from TPUPartionedCall.
359       if (llvm::isa<TF::TPUPartitionedCallOp>(user)) {
360         function_worklist.push(func);
361       }
362       // Populates caller to called func map.
363       if (func::FuncOp caller = user->getParentOfType<func::FuncOp>()) {
364         caller_callee_map[caller].insert(func);
365       }
366     }
367   }
368 
369   // Determines reached ops starting from TPUPartionedCall ops
370   // and iteratively descending through called ops.
371   SmallPtrSet<Operation*, 16> reachable_functions;
372   while (!function_worklist.empty()) {
373     func::FuncOp caller = function_worklist.front();
374     function_worklist.pop();
375     if (reachable_functions.insert(caller).second) {
376       for (auto callee : caller_callee_map[caller]) {
377         function_worklist.push(callee);
378       }
379     }
380   }
381   return reachable_functions;
382 }
383 
384 // valid means all the ops in the vector are belong to the same cluster.
is_valid_special_tpu_op(std::vector<IslandOp> & ops,llvm::StringRef cluster_name,llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_op_map)385 bool is_valid_special_tpu_op(
386     std::vector<IslandOp>& ops, llvm::StringRef cluster_name,
387     llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
388         cluster_to_tpu_op_map) {
389   for (IslandOp op : ops) {
390     Operation* wrapped_op = &op.GetBody().front();
391     llvm::Optional<llvm::StringRef> wrapped_op_cluster_name =
392         GetTpuClusterName(wrapped_op);
393 
394     bool op_has_inconsistent_cluster_name =
395         wrapped_op_cluster_name.has_value() &&
396         !wrapped_op_cluster_name.getValue().equals(cluster_name);
397 
398     if (op_has_inconsistent_cluster_name) {
399       return false;
400     }
401   }
402   return true;
403 }
404 
collect_input_defining_islands(IslandOp op,std::vector<IslandOp> & ops)405 void collect_input_defining_islands(IslandOp op, std::vector<IslandOp>& ops) {
406   Operation* wrapped_op = &op.GetBody().front();
407   for (Value operand : wrapped_op->getOperands()) {
408     IslandOp wrapper = dyn_cast_or_null<IslandOp>(operand.getDefiningOp());
409     if (!wrapper || !wrapper.WrapsSingleOp()) continue;
410     ops.push_back(wrapper);
411   }
412 }
413 
collect_output_users_islands(IslandOp op,std::vector<IslandOp> & ops)414 void collect_output_users_islands(IslandOp op, std::vector<IslandOp>& ops) {
415   for (Value result : op->getResults()) {
416     for (OpOperand use : result.getUsers()) {
417       IslandOp wrapper =
418           dyn_cast_or_null<IslandOp>(use.getOwner()->getParentOp());
419       if (!wrapper || !wrapper.WrapsSingleOp()) continue;
420       ops.push_back(wrapper);
421     }
422   }
423 }
424 
AddSpecialTpuOps(IslandOp candidate_island,llvm::StringRef cluster_name,llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_op_map,SmallPtrSetImpl<Operation * > & visited_wrapped_ops,bool incoming)425 bool AddSpecialTpuOps(
426     IslandOp candidate_island, llvm::StringRef cluster_name,
427     llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
428         cluster_to_tpu_op_map,
429     SmallPtrSetImpl<Operation*>& visited_wrapped_ops, bool incoming) {
430   std::queue<IslandOp> op_worklist;
431   std::vector<IslandOp> ops;
432 
433   op_worklist.push(candidate_island);
434 
435   while (!op_worklist.empty()) {
436     IslandOp current_op = op_worklist.front();
437     op_worklist.pop();
438     ops.clear();
439     if (incoming) {
440       collect_input_defining_islands(current_op, ops);
441     } else {
442       collect_output_users_islands(current_op, ops);
443     }
444     for (IslandOp wrapper : ops) {
445       Operation* wrapped_op = &wrapper.GetBody().front();
446       std::vector<IslandOp> child_ops;
447       if (incoming) {
448         // Looks at captured operands of `candidate_wrapped_op` to bring special
449         // TPU ops such as tf.TPUReplicatedInput and tf.TPUPartitionedInput into
450         // the island as well. These ops are brought in only if they do not
451         // already have a cluster assigned to them (via `_replication_info`
452         // attribute value).
453         // `tf.Identity` op is also treated as special tpu ops since it can play
454         // a role as connection between `tf.TPUReplicatedInput` or
455         // `tf.TPUPartitionedInput`. For example, we have the follow pseudocode:
456         // %0 = tf_executor.island wraps "tf.OpA" (){_replication_info = 'c'}
457         // %1 = tf_executor.island wraps "tf.Identity(%0)
458         // %2 = tf_executor.island wraps "tf.TPUReplicatedInput"(%1)
459 
460         if (!isa<TF::TPUReplicatedInputOp, TF::TPUPartitionedInputOp,
461                  TF::IdentityOp>(wrapped_op))
462           continue;
463         collect_output_users_islands(wrapper, child_ops);
464       } else {
465         // Looks at the results of `candidate_island` to bring special TPU
466         // ops such as tf.TPUReplicatedOutput and tf.TPUPartitionedOutput into
467         // the island as well. These ops are brought in only if they do not
468         // already have cluster (`_tpu_replicate` attribute) assigned to them.
469         // `tf.Identity` op is also treated as special tpu ops since it can play
470         // a role as connection between `tf.TPUReplicatedOutput` or
471         // `tf.TPUPartitionedInput`.
472         if (!isa<TF::TPUReplicatedOutputOp, TF::TPUPartitionedOutputOp,
473                  TF::IdentityOp>(wrapped_op))
474           continue;
475         collect_input_defining_islands(wrapper, child_ops);
476       }
477       if (!is_valid_special_tpu_op(child_ops, cluster_name,
478                                    cluster_to_tpu_op_map)) {
479         return false;
480       }
481 
482       // Only inputs/outputs that do not have a cluster name assigned are
483       // considered for special handling. Otherwise, island coarsening logic
484       // should be able to handle it.
485       if (wrapped_op->hasAttrOfType<StringAttr>(TF::kReplicationInfoAttr))
486         continue;
487       if (visited_wrapped_ops.contains(wrapped_op)) continue;
488       op_worklist.push(wrapper);
489       cluster_to_tpu_op_map[cluster_name].insert(wrapper);
490       visited_wrapped_ops.insert(wrapped_op);
491     }
492   }
493   return true;
494 }
495 
CollectSpecialTpuOps(llvm::function_ref<bool (llvm::StringRef,Operation *)> is_op_calling_func_for_cluster,Operation * op,llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_op_map,SmallPtrSet<Operation *,16> & visited_wrapped_ops)496 LogicalResult CollectSpecialTpuOps(
497     llvm::function_ref<bool(llvm::StringRef, Operation*)>
498         is_op_calling_func_for_cluster,
499     Operation* op,
500     llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
501         cluster_to_tpu_op_map,
502     SmallPtrSet<Operation*, 16>& visited_wrapped_ops) {
503   IslandOp island = dyn_cast<IslandOp>(*op);
504   if (!island || !island.WrapsSingleOp()) return success();
505   Operation& wrapped_op = island.GetBody().front();
506 
507   if (visited_wrapped_ops.contains(&wrapped_op)) return success();
508 
509   llvm::Optional<llvm::StringRef> result = GetTpuClusterName(&wrapped_op);
510   if (!result.has_value()) return success();
511   llvm::StringRef cluster_name = result.getValue();
512 
513   visited_wrapped_ops.insert(&wrapped_op);
514 
515   if (!AddSpecialTpuOps(island, cluster_name, cluster_to_tpu_op_map,
516                         visited_wrapped_ops, /*incoming=*/true)) {
517     return failure();
518   }
519   if (!AddSpecialTpuOps(island, cluster_name, cluster_to_tpu_op_map,
520                         visited_wrapped_ops, /*incoming=*/false)) {
521     return failure();
522   }
523   return success();
524 }
525 
526 // Whenever we find an Identity op that is unqualified, we remove this Identity
527 // op from the list `tpu_ops`. An unqualified Identity op indicates either its
528 // inputs or its outputs do not belong to the same cluster.
ExcludeIdentityOp(llvm::SmallDenseSet<Operation * > & tpu_ops,llvm::StringRef & target_cluster_name,bool incoming)529 bool ExcludeIdentityOp(llvm::SmallDenseSet<Operation*>& tpu_ops,
530                        llvm::StringRef& target_cluster_name, bool incoming) {
531   for (auto iter = tpu_ops.begin(); iter != tpu_ops.end(); iter++) {
532     auto island_op = llvm::dyn_cast<IslandOp>(*iter);
533     if (llvm::dyn_cast_or_null<TF::IdentityOp>(island_op.GetBody().front())) {
534       if (island_op.outputs().use_empty()) {
535         tpu_ops.erase(iter);
536         return true;
537       }
538       std::vector<IslandOp> ops;
539       if (incoming) {
540         collect_output_users_islands(island_op, ops);
541       } else {
542         collect_input_defining_islands(island_op, ops);
543       }
544       for (IslandOp wrapper : ops) {
545         Operation* wrapped_op = &wrapper.GetBody().front();
546         auto cluster_name = GetTpuClusterName(wrapped_op);
547         if (cluster_name.hasValue() &&
548             cluster_name.getValue() != target_cluster_name) {
549           tpu_ops.erase(iter);
550           return true;
551         }
552         if (!cluster_name.hasValue() &&
553             !tpu_ops.count(wrapper.getOperation())) {
554           tpu_ops.erase(iter);
555           return true;
556         }
557       }
558     }
559   }
560   return false;
561 }
562 
ExcludeUnqualifiedIdentityOp(llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_ops_map,bool incoming)563 void ExcludeUnqualifiedIdentityOp(
564     llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
565         cluster_to_tpu_ops_map,
566     bool incoming) {
567   for (auto& [target_cluster_name, tpu_ops] : cluster_to_tpu_ops_map) {
568     bool changed = true;
569     while (changed) {
570       changed = ExcludeIdentityOp(tpu_ops, target_cluster_name, incoming);
571     }
572   }
573 }
574 
ExcludeUnqualifiedIdentityOp(llvm::SmallDenseMap<llvm::StringRef,llvm::SmallDenseSet<Operation * >> & cluster_to_tpu_ops_map)575 void ExcludeUnqualifiedIdentityOp(
576     llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>&
577         cluster_to_tpu_ops_map) {
578   ExcludeUnqualifiedIdentityOp(cluster_to_tpu_ops_map, /*incoming=*/true);
579   ExcludeUnqualifiedIdentityOp(cluster_to_tpu_ops_map, /*incoming=*/false);
580 }
581 
582 // Erase Identity op which does not contain `_replication_info` in the merged
583 // island.
EraseIdentityWithNoReplicationInfo(Block & graph_body)584 void EraseIdentityWithNoReplicationInfo(Block& graph_body) {
585   for (Operation& island_op : graph_body) {
586     IslandOp island = dyn_cast<IslandOp>(island_op);
587     if (!island || island.WrapsSingleOp()) continue;
588     for (Operation& op : llvm::make_early_inc_range(island.GetBody())) {
589       llvm::Optional<llvm::StringRef> cluster_name = GetTpuClusterName(&op);
590       if (cluster_name.hasValue()) continue;
591       if (auto identity_op = llvm::dyn_cast_or_null<TF::IdentityOp>(op)) {
592         auto identity_input = identity_op.input();
593         auto output = identity_op.output();
594         output.replaceAllUsesWith(identity_input);
595         identity_op.erase();
596       }
597     }
598   }
599 }
600 
runOnOperation()601 void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
602   SymbolTable symbol_table(getOperation());
603 
604   // Map tpu cluster names to the functions that contain operations for this
605   // cluster.
606   DenseMap<StringRef, DenseSet<func::FuncOp>> tpu_funcs;
607   for (func::FuncOp func_op : getOperation().getOps<func::FuncOp>()) {
608     func_op.walk([&](Operation* op) {
609       llvm::Optional<llvm::StringRef> cluster_name_opt = GetTpuClusterName(op);
610       if (cluster_name_opt.has_value()) {
611         tpu_funcs[cluster_name_opt.getValue()].insert(func_op);
612       }
613     });
614   }
615 
616   // Return true if the operation is containing a reference to a function
617   // containing operations for this cluster.
618   auto is_op_calling_func_for_cluster = [&](llvm::StringRef cluster,
619                                             Operation* op) {
620     auto funcs_for_cluster = tpu_funcs.find(cluster);
621     assert(funcs_for_cluster != tpu_funcs.end());
622     assert(!funcs_for_cluster->second.empty());
623     if (funcs_for_cluster->second.size() == 1) return false;
624     for (NamedAttribute attr : op->getAttrs()) {
625       auto symbol_ref = attr.getValue().dyn_cast<FlatSymbolRefAttr>();
626       if (!symbol_ref) continue;
627       func::FuncOp callee =
628           symbol_table.lookup<func::FuncOp>(symbol_ref.getValue());
629       if (!callee) continue;
630       if (funcs_for_cluster->second.count(callee)) return true;
631     }
632     return false;
633   };
634 
635   // Populates skip set with functions reachable from TPUPartionedCall ops.
636   const auto functions_to_skip =
637       FindTPUPartitionedCallReachableFunctions(getOperation());
638   for (func::FuncOp func_op : getOperation().getOps<func::FuncOp>()) {
639     if (functions_to_skip.contains(func_op)) {
640       OpBuilder builder(func_op);
641       // Mark this function as being skipped in island outlining.
642       func_op->setAttr(mlir::TF::kSkipIslandOutlining,
643                        builder.getBoolAttr(true));
644       continue;
645     }
646 
647     func_op.walk([&](GraphOp graph) {
648       Block& graph_body = graph.GetBody();
649       llvm::SmallDenseMap<llvm::StringRef, llvm::SmallDenseSet<Operation*>>
650           cluster_to_tpu_ops_map;
651       SmallPtrSet<Operation*, 16> visited_ops;
652       for (Operation& op : graph_body) {
653         if (failed(CollectSpecialTpuOps(is_op_calling_func_for_cluster, &op,
654                                         cluster_to_tpu_ops_map, visited_ops))) {
655           graph.emitError()
656               << "Collect special Tpu ops failed: "
657               << "Graph contains op with inconsistent cluster info\n";
658           signalPassFailure();
659           return WalkResult::interrupt();
660         }
661       }
662 
663       ExcludeUnqualifiedIdentityOp(cluster_to_tpu_ops_map);
664 
665       // Iterate until fixed point on the block, as it may contain multiple
666       // clusters.
667       bool changed = true;
668       while (changed) {
669         changed = false;
670         for (Operation& op : graph_body) {
671           if (failed(MergeIsland(is_op_calling_func_for_cluster,
672                                  cluster_to_tpu_ops_map, &op, &changed))) {
673             graph.emitError()
674                 << "Merging island failed: the TPU cluster likely "
675                 << "contains a cycle with non-TPU operations or has "
676                    "unsupported ops\n";
677             signalPassFailure();
678             return WalkResult::interrupt();
679           }
680           // If islands were merged, restart scanning the block from the
681           // beginning as we lost track of where to continue.
682           if (changed) break;
683         }
684       }
685 
686       // Need to remove the redundant `Identity` ops in the same cluster.
687       // Redundant `Identity` op indicates that no `_replicatation_info`
688       // attribute is attached.
689       EraseIdentityWithNoReplicationInfo(graph_body);
690 
691       return WalkResult::advance();
692     });
693   }
694 }
695 
696 }  // namespace
697 
698 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass()699 CreateTFExecutorTPUV1IslandCoarseningPass() {
700   return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
701 }
702 
703 }  // namespace tf_executor
704 }  // namespace mlir
705