1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/DenseSet.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/iterator_range.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "mlir/IR/Attributes.h"  // from @llvm-project
36 #include "mlir/IR/Builders.h"  // from @llvm-project
37 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
38 #include "mlir/IR/Operation.h"  // from @llvm-project
39 #include "mlir/IR/Types.h"  // from @llvm-project
40 #include "mlir/IR/Value.h"  // from @llvm-project
41 #include "mlir/Pass/Pass.h"  // from @llvm-project
42 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
43 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
44 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
45 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
46 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
50 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
51 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
52 
53 namespace mlir {
54 namespace TFTPU {
55 
56 namespace {
57 
58 constexpr llvm::StringRef kDeviceAttr = "device";
59 constexpr llvm::StringRef kNameAttr = "name";
60 constexpr llvm::StringRef kNumCoresPerReplicaAttr = "num_cores_per_replica";
61 constexpr llvm::StringRef kNumReplicasAttr = "num_replicas";
62 constexpr llvm::StringRef kMirroredVariableIndicesAttr =
63     "_mirrored_variable_indices";
64 constexpr llvm::StringRef kNoReplicationCluster = "__no_replication_cluster";
65 
66 constexpr llvm::StringRef kBadReplicateInfoAttrMsg =
67     "requires '_replication_info' string attribute";
68 
69 // Mapping for `_replication_info` attribute to TPUReplicateMetadata attributes.
70 using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttrList, 8>;
71 
72 // A set of operations. We use a `SmallSetVector` in order to have deterministic
73 // traversal order (= insertion order), independent of the pointer keys.
74 using OpSetVector = llvm::SmallSetVector<Operation*, 8>;
75 
76 // Mapping for `_replication_info` attribute to ops of a cluster.
77 using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, OpSetVector, 8>;
78 
79 struct TPUClusterFormationPass
80     : public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
getDependentDialectsmlir::TFTPU::__anon1e87b7740111::TPUClusterFormationPass81   void getDependentDialects(DialectRegistry& registry) const override {
82     registry.insert<tf_device::TensorFlowDeviceDialect>();
83   }
84 
85   void runOnOperation() override;
86 };
87 
88 // Creates a mapping from the TPUReplicateMetadata ops `_replication_info`
89 // attribute to its attributes and removes the ops. If multiple
90 // TPUReplicateMetadata ops have the same `_replication_info` attribute, an
91 // error will be returned.
CollectMetadata(Block * block,MetadataMap * metadata_map)92 LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
93   // Just look at top-level operations in the block (not nested ones)
94   for (Operation& op : llvm::make_early_inc_range(*block)) {
95     auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
96     if (!metadata_op) continue;
97 
98     NamedAttrList attrs(metadata_op->getAttrDictionary());
99 
100     // Missing or bad `_replication_info` attribute.
101     auto replication_info_attr = attrs.get(TF::kReplicationInfoAttr);
102     if (!replication_info_attr)
103       return metadata_op.emitError() << kBadReplicateInfoAttrMsg;
104 
105     auto replication_info_attr_str =
106         replication_info_attr.dyn_cast<StringAttr>();
107     if (!replication_info_attr_str ||
108         replication_info_attr_str.getValue().empty())
109       return metadata_op.emitError() << kBadReplicateInfoAttrMsg;
110 
111     // Remove `name` attribute.
112     attrs.erase(StringAttr::get(metadata_op.getContext(), kNameAttr));
113 
114     auto it = metadata_map->try_emplace(replication_info_attr_str.getValue(),
115                                         std::move(attrs));
116 
117     // There are multiple TPUReplicateMetadata ops with the same
118     // `_replication_info` attribute.
119     if (!it.second) {
120       return metadata_op.emitError()
121              << "multiple TPUReplicateMetadata ops with the same '"
122              << TF::kReplicationInfoAttr << "' attribute '"
123              << replication_info_attr_str.getValue() << "' found";
124     }
125     metadata_op.erase();
126   }
127   return success();
128 }
129 
130 // Collects and clusters ops either based on `_replication_info` attribute
131 // (replicated case) or using one single cluster (non-replicated case). Also
132 // sets `device_type` if there is any cluster (note that the device type must be
133 // unique, otherwise we emit an error).
134 // Returns an error in case of invalid compilation or replication attribute(s).
CollectAndGroupClusterOps(Block * block,ClusterMap * clusters,std::string & device_type)135 LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters,
136                                         std::string& device_type) {
137   bool has_replicated_compiled_op = false;
138   bool has_non_replicated_compiled_op = false;
139   // Use ordered set here to make error message below deterministic.
140   std::set<llvm::StringRef> device_types;
141   for (Operation& op : *block) {
142     LogicalResult result = TF::HasValidCompilationAndReplicationAttributes(op);
143     if (failed(result)) return result;
144 
145     // Collect device types which currently must be consistent per block
146     // (checked later).
147     auto device_type_attr =
148         op.getAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr);
149     if (device_type_attr) device_types.insert(device_type_attr);
150 
151     if (op.hasAttr(TF::kReplicationInfoAttr)) {
152       // For replicated case, borrow cluster structure from replication info.
153       // Following condition is already checked in
154       // `HasValidCompilationAndReplicationAttributes` above, assert here for
155       // documentation and to avoid breakage when that function is changed.
156       assert(op.hasAttr(TF::kCompileDeviceTypeAttr));
157       has_replicated_compiled_op = true;
158       auto attr = op.getAttrOfType<StringAttr>(TF::kReplicationInfoAttr);
159       auto it = clusters->try_emplace(attr.getValue());
160       it.first->getSecond().insert(&op);
161     } else if (op.hasAttr(TF::kCompileDeviceTypeAttr)) {
162       // For non-replicated case, assume one cluster per block (in line with
163       // Framework behavior).
164       has_non_replicated_compiled_op = true;
165       auto it = clusters->try_emplace(kNoReplicationCluster);
166       it.first->getSecond().insert(&op);
167     }
168   }
169   // Do some checks for unsupported cases.
170   if (has_replicated_compiled_op && has_non_replicated_compiled_op) {
171     return block->getParentOp()->emitError()
172            << "found mixed replicated and non-replicated compiled ops in same "
173               "block which is not supported";
174   }
175   if (device_types.size() > 1) {
176     return block->getParentOp()->emitError()
177            << "found different '" << TF::kCompileDeviceTypeAttr
178            << "' attribute values (" << llvm::join(device_types, ",")
179            << ") in same block which is not supported";
180   }
181   if (!clusters->empty()) {
182     // Note that for size < 1 we shouldn't have any cluster while for size > 1
183     // we should have returned with an error above.
184     assert(device_types.size() == 1);
185     device_type = device_types.begin()->str();
186   }
187   return success();
188 }
189 
190 // Returns true iff `op` has a direct control dependency from (`incoming` ==
191 // true) or to (`incoming` == false) any op in `cluster_ops` or
192 // `cluster_dependent_ops`.
hasOpClusterControlDependency(Operation * op,bool incoming,const OpSetVector & cluster_ops,const OpSetVector & cluster_dependent_ops,const TF::SideEffectAnalysis::Info & side_effect_analysis)193 bool hasOpClusterControlDependency(
194     Operation* op, bool incoming, const OpSetVector& cluster_ops,
195     const OpSetVector& cluster_dependent_ops,
196     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
197   auto filter = [&](Operation* other_op) {
198     return cluster_ops.contains(other_op) ||
199            cluster_dependent_ops.contains(other_op);
200   };
201   return incoming ? !side_effect_analysis.DirectControlPredecessors(op, filter)
202                          .empty()
203                   : !side_effect_analysis.DirectControlSuccessors(op, filter)
204                          .empty();
205 }
206 
207 // Returns true iff `op` has a direct data dependency from (`incoming` == true
208 // or to (`incoming` == false) any op in `cluster_ops` or
209 // `cluster_dependent_ops`.
hasOpClusterDataDependency(Operation * op,bool incoming,const OpSetVector & cluster_ops,const OpSetVector & cluster_dependent_ops)210 bool hasOpClusterDataDependency(Operation* op, bool incoming,
211                                 const OpSetVector& cluster_ops,
212                                 const OpSetVector& cluster_dependent_ops) {
213   auto result = op->walk([&](Operation* inner_op) {
214     ValueRange values = incoming ? ValueRange(inner_op->getOperands())
215                                  : ValueRange(inner_op->getResults());
216     llvm::SmallVector<Operation*, 4> candidates;
217     for (Value value : values) {
218       if (incoming) {
219         candidates = {value.getDefiningOp()};
220       } else {
221         candidates.assign(value.getUsers().begin(), value.getUsers().end());
222       }
223       for (Operation* candidate_op : candidates) {
224         if (cluster_ops.contains(candidate_op) ||
225             cluster_dependent_ops.contains(candidate_op)) {
226           return WalkResult::interrupt();
227         }
228       }
229     }
230     return WalkResult::advance();
231   });
232   return result.wasInterrupted();
233 }
234 
235 // Collects ops that need to be moved behind the cluster due to data or control
236 // dependencies.
CollectClusterSuccessorOps(Block * block,const OpSetVector & cluster_ops,const TF::SideEffectAnalysis::Info & side_effect_analysis)237 llvm::SmallSetVector<Operation*, 8> CollectClusterSuccessorOps(
238     Block* block, const OpSetVector& cluster_ops,
239     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
240   OpSetVector cluster_predecessor_ops;
241   OpSetVector cluster_successor_ops;
242 
243   // Collect non-cluster ops that have a dependency to the cluster. For this
244   // traverse all ops from last to first cluster op and keep track of in-between
245   // non-cluster ops that have some outgoing (transitive) dependency to some
246   // cluster op (`cluster_predecessor_ops`).
247   auto rfront = Block::reverse_iterator(cluster_ops.front());
248   auto rback = Block::reverse_iterator(cluster_ops.back());
249   for (Operation& op : llvm::make_range(rback, rfront)) {
250     if (cluster_ops.contains(&op)) continue;
251     bool has_dependency_to_cluster =
252         hasOpClusterDataDependency(&op, /*incoming=*/false, cluster_ops,
253                                    cluster_predecessor_ops) ||
254         hasOpClusterControlDependency(&op, /*incoming=*/false, cluster_ops,
255                                       cluster_predecessor_ops,
256                                       side_effect_analysis);
257     if (has_dependency_to_cluster) cluster_predecessor_ops.insert(&op);
258   }
259   // Collect non-cluster ops that have a dependency from the cluster. For this
260   // traverse all ops from first to last cluster op and keep track of in-between
261   // non-cluster ops that have some incoming (transitive) dependency from some
262   // cluster op (`cluster_successor_ops`).
263   auto front = Block::iterator(cluster_ops.front());
264   auto back = Block::iterator(cluster_ops.back());
265   for (Operation& op : llvm::make_range(front, back)) {
266     if (cluster_ops.contains(&op)) continue;
267     bool has_dependency_from_cluster =
268         hasOpClusterDataDependency(&op, /*incoming=*/true, cluster_ops,
269                                    cluster_successor_ops) ||
270         hasOpClusterControlDependency(&op, /*incoming=*/true, cluster_ops,
271                                       cluster_successor_ops,
272                                       side_effect_analysis);
273     if (has_dependency_from_cluster) {
274       if (cluster_predecessor_ops.contains(&op)) {
275         // Op has a dependency from and to the cluster which is invalid. Instead
276         // of erroring out we don't add the op to `cluster_successor_ops` which
277         // is in line with previous behavior when certain control dependencies
278         // were not considered.
279         // TODO(b/216706460) Establish some contract here: Should we expect only
280         // valid clusters, or should we split clusters accordingly? The latter
281         // might have runtime impact for existing models.
282         // We should make this message an error once there is such a contract
283         // and once existing cases have been fixed.
284         op.emitWarning()
285             << "op has cyclic dependency with a compilation cluster";
286       } else {
287         cluster_successor_ops.insert(&op);
288       }
289     }
290   }
291   return cluster_successor_ops;
292 }
293 
294 // Collects results and associated types of the cluster that are used outside of
295 // the cluster. These results and types are used to create the clusters
296 // `tf_device.cluster` and associated terminator. Results that have no uses
297 // outside of the cluster (i.e. results of ops in the cluster are only consumed
298 // by other ops in the cluster) are pruned.
CollectClusterResults(Block * block,const OpSetVector & cluster_ops)299 llvm::SmallVector<Value, 8> CollectClusterResults(
300     Block* block, const OpSetVector& cluster_ops) {
301   llvm::SmallVector<Value, 8> results;
302 
303   for (Operation* op : cluster_ops) {
304     for (Value result : op->getResults()) {
305       for (Operation* user : result.getUsers()) {
306         // Check if user is not an op in the cluster.
307         if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
308           results.push_back(result);
309           break;
310         }
311       }
312     }
313   }
314 
315   return results;
316 }
317 
318 // Creates a `tf_device.cluster` to wrap cluster ops.
CreateClusterOp(Block * block,const OpSetVector & cluster_ops,llvm::ArrayRef<Value> results,llvm::ArrayRef<Operation * > cluster_successor_ops)319 tf_device::ClusterOp CreateClusterOp(
320     Block* block, const OpSetVector& cluster_ops, llvm::ArrayRef<Value> results,
321     llvm::ArrayRef<Operation*> cluster_successor_ops) {
322   // `tf_device.cluster` will be placed at where the last op of the cluster is.
323   Operation* last_cluster_op = cluster_ops.back();
324   OpBuilder builder(last_cluster_op);
325 
326   llvm::SmallVector<Type, 8> result_types;
327   for (Value result : results) result_types.push_back(result.getType());
328   auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
329                                                       result_types);
330 
331   Block* body = new Block;
332   cluster.body().push_back(body);
333 
334   // Move cluster ops to the cluster body. Also remove `_replication_info` and
335   // `device` attribute from ops in the cluster when that information is
336   // redundant will the `tf_device.cluster`. Do this for all ops including
337   // nested ops.
338   for (Operation* cluster_op : cluster_ops) {
339     cluster_op->moveBefore(body, body->end());
340     cluster_op->walk([&](Operation* inner_op) {
341       inner_op->removeAttr(TF::kReplicationInfoAttr);
342       inner_op->removeAttr(TF::kCompileDeviceTypeAttr);
343 
344       if (auto attr = inner_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
345         // Preserve device attribute if the op is placed on a replicated core
346         // device. Device attribute is used to infer the appropriate sharding
347         // within TPUs for this op.
348         // TODO(b/183598857): Use explicit sharding ops from the front-end.
349         // For example, dequeue ops generated by
350         // tensorflow/python/tpu/tpu_feed.py
351         if (!tensorflow::IsTPUReplicatedCore(attr.getValue())) {
352           inner_op->removeAttr(kDeviceAttr);
353         }
354       }
355     });
356   }
357 
358   // Add terminator.
359   builder.setInsertionPointToEnd(body);
360   builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
361 
362   // Replaces uses of cluster ops results outside of cluster with the associated
363   // `tf_device.cluster` results.
364   for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
365     Value old_ret = std::get<0>(ret_vals);
366     Value new_ret = std::get<1>(ret_vals);
367     for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
368       Operation* user = use.getOwner();
369       if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
370     }
371   }
372 
373   // Move ops that depend on something in the cluster behind the cluster.
374   Operation* op_after_cluster = cluster.getOperation()->getNextNode();
375   for (Operation* op : cluster_successor_ops) op->moveBefore(op_after_cluster);
376   return cluster;
377 }
378 
379 // Creates a `tf_device.replicate` to represent replication for the cluster, if
380 // necessary.
ReplicateCluster(tf_device::ClusterOp cluster,int num_replicas,int num_cores_per_replica)381 LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
382                                int num_cores_per_replica) {
383   // No need to replicate.
384   if (num_replicas == 1) return success();
385 
386   if (num_replicas < 1)
387     return cluster.emitError() << "requires '" << kNumReplicasAttr
388                                << "' int attribute to be at least 1";
389 
390   LogicalResult status = success();
391   // Collect all used TPUReplicatedInput ops.
392   llvm::SmallVector<Operation*, 8> replicated_input_ops;
393   mlir::visitUsedValuesDefinedAbove(
394       cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
395         Operation* def = operand->get().getDefiningOp();
396         if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(def))
397           replicated_input_ops.push_back(def);
398         // When model parallelism is used in conjunction with data parallelism
399         // for resource inputs, we need to collect the per replica resource
400         // inputs from input to `tf.TPUPartitionedInput` ops.
401         if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
402           if (pi->getNumOperands() != num_cores_per_replica)
403             status = pi.emitOpError()
404                      << "requires " << num_cores_per_replica
405                      << " operands but found " << pi->getNumOperands();
406           for (auto operand : pi.inputs()) {
407             if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(
408                     operand.getDefiningOp()))
409               replicated_input_ops.push_back(operand.getDefiningOp());
410           }
411         }
412       });
413 
414   if (failed(status)) return failure();
415 
416   // Indices of the replicate op's arguments that are mirrored variables.
417   llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
418 
419   // Check if number of operands of each used TPUReplicatedInput op matches
420   // `num_replicas` or 1. Collect all their operands and associated type for
421   // creating the replicate op.
422   llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
423   llvm::SmallVector<Value, 8> packed_inputs;
424   llvm::SmallVector<Operation*, 8> replicated_ops;
425   llvm::SmallVector<Operation*, 8> packed_ops;
426   for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
427     auto input = pos_and_input.value();
428     bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
429     const int num_operands = input->getNumOperands();
430     int num_inputs = is_packed ? 1 : num_replicas;
431     if (num_operands != num_inputs)
432       return input->emitOpError() << "requires " << num_inputs << " operands";
433     if (is_packed) {
434       packed_inputs.push_back(input->getOperand(0));
435       packed_ops.push_back(input);
436     } else {
437       replicated_inputs.push_back(
438           {input->getOperands(), input->getOperand(0).getType()});
439       replicated_ops.push_back(input);
440     }
441   }
442 
443   // Create `ordered_tpu_replicate_inputs` which constains the final ordered
444   // replicate inputs. All packed arguments are moved to the end of the arg
445   // list.
446   llvm::SmallVector<Operation*, 8> ordered_tpu_replicate_inputs =
447       replicated_ops;
448   ordered_tpu_replicate_inputs.append(packed_ops.begin(), packed_ops.end());
449 
450   // Assign `mirrored_variable_indices` based on the ordered replicated inputs.
451   for (const auto& pos_and_input :
452        llvm::enumerate(ordered_tpu_replicate_inputs)) {
453     auto tpu_replicated_input =
454         llvm::cast<TF::TPUReplicatedInputOp>(pos_and_input.value());
455     if (tpu_replicated_input.is_mirrored_variable()) {
456       mirrored_variable_indices.push_back(pos_and_input.index());
457     }
458   }
459 
460   // Create replicate op.
461   OpBuilder builder(cluster);
462   auto replicate_op = builder.create<tf_device::ReplicateOp>(
463       cluster.getLoc(), num_replicas,
464       llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
465       replicated_inputs, packed_inputs, cluster.getResultTypes());
466 
467   if (!mirrored_variable_indices.empty())
468     replicate_op->setAttr(kMirroredVariableIndicesAttr,
469                           builder.getI64ArrayAttr(mirrored_variable_indices));
470 
471   // Replace replicated cluster results with replicate op results.
472   for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
473     Value result = result_and_idx.value();
474     int idx = result_and_idx.index();
475     auto replicate_outputs = llvm::make_range(
476         std::next(replicate_op.result_begin(), idx * num_replicas),
477         std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
478 
479     for (auto& use : llvm::make_early_inc_range(result.getUses())) {
480       Operation* def = use.getOwner();
481       if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
482         // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
483         // replica output. Certain Graphs under V1 create `tf.Identity` users of
484         // replicated ops to pin the TPU computation for execution.
485         use.set(*replicate_outputs.begin());
486         continue;
487       }
488 
489       const int def_num_results = def->getNumResults();
490       if (def_num_results != num_replicas)
491         return def->emitOpError() << "requires " << num_replicas << " results";
492 
493       def->replaceAllUsesWith(replicate_outputs);
494     }
495   }
496 
497   // Collect all `tf.TPUPartitionedInput` ops to be moved inside the
498   // `tf_device.replicate` later.
499   llvm::SmallSet<Operation*, 4> partitioned_inputs;
500   for (auto input_and_block_arg :
501        llvm::zip(ordered_tpu_replicate_inputs,
502                  replicate_op.GetBody().getArguments())) {
503     Operation* input = std::get<0>(input_and_block_arg);
504     Value block_arg = std::get<1>(input_and_block_arg);
505     mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
506                                      cluster.body());
507     // Update replicated input use in tf.TPUPartitionedInput op.
508     for (auto& use : input->getUses()) {
509       auto pi = llvm::dyn_cast<TF::TPUPartitionedInputOp>(use.getOwner());
510       if (pi) {
511         pi.setOperand(use.getOperandNumber(), block_arg);
512         partitioned_inputs.insert(pi.getOperation());
513       }
514     }
515   }
516 
517   // Create terminator for replicate op and move `tf_device.cluster` and
518   // `tf.TPUPartitionedInput`(s) into replicate body.
519   builder.setInsertionPointToEnd(&replicate_op.GetBody());
520   auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
521                                                        cluster.getResults());
522   for (auto pi : partitioned_inputs) pi->moveBefore(return_op);
523 
524   cluster.getOperation()->moveBefore(return_op);
525 
526   return success();
527 }
528 
SetNoReplicationClusterAttrs(tf_device::ClusterOp cluster,llvm::StringRef device_type)529 void SetNoReplicationClusterAttrs(tf_device::ClusterOp cluster,
530                                   llvm::StringRef device_type) {
531   OpBuilder builder(cluster);
532   cluster->setAttr(TF::kReplicationInfoAttr,
533                    builder.getStringAttr(kNoReplicationCluster));
534   cluster->setAttr(TF::kCompileDeviceTypeAttr,
535                    builder.getStringAttr(device_type));
536 
537   // TODO(b/229992058) Propagate `allow_soft_placement` (and other attributes?)
538   // instead of hard-coding.
539   cluster->setAttr("allow_soft_placement", builder.getBoolAttr(true));
540   cluster->setAttr("topology", builder.getStringAttr(""));
541   cluster->setAttr("num_cores_per_replica",
542                    builder.getIntegerAttr(builder.getI32Type(), 1));
543   cluster->setAttr("device_assignment", builder.getArrayAttr({}));
544   cluster->setAttr("use_spmd_for_xla_partitioning", builder.getBoolAttr(false));
545   cluster->setAttr("step_marker_location", builder.getStringAttr(""));
546 }
547 
548 // Forms compilation clusters in `block`. If the block contains a
549 // `TPUReplicateMetadata` op, then we form clusters according to
550 // `_replication_info` values (ops with same value go to same cluster).
551 // Otherwise, in the non-replicated case, we build one compilation cluster per
552 // block.
553 //
554 // We do this in following steps:
555 //   1. Find `TPUReplicateMetadata` op in `block` (might not exist).
556 //   2. Collect and group cluster ops (either based on `_replication_info`
557 //      attributes or forming one single cluster).
558 //   3. Find external uses of cluster ops.
559 //   4. Create `tf_device.cluster` with results consisting of the external uses
560 //      of cluster ops determined at 3.
561 //   5. Move cluster ops to `tf_device.cluster` body.
562 //   6. Replace external uses of cluster ops uses with `tf_device.cluster`
563 //      results.
564 //   7. Move users from 2 to after the `tf_device.cluster`.
565 //   8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
566 //      attribute `num_replicas` is greater than 1.
567 //   9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
FormClustersInBlock(Block * block,const TF::SideEffectAnalysis::Info & side_effect_analysis)568 LogicalResult FormClustersInBlock(
569     Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
570   MetadataMap metadata_map;
571   LogicalResult result = CollectMetadata(block, &metadata_map);
572   if (failed(result)) return result;
573 
574   // If there is no TPUReplicateMetadata op in this block, process blocks in
575   // regions attached to the op's in the block.
576   if (metadata_map.empty()) {
577     for (Operation& op : *block) {
578       for (Region& region : op.getRegions()) {
579         if (!llvm::hasSingleElement(region))
580           return op.emitOpError("Expected single block region");
581         if (failed(FormClustersInBlock(&region.front(), side_effect_analysis)))
582           return failure();
583       }
584     }
585   }
586 
587   ClusterMap clusters;
588   std::string device_type;
589   result = CollectAndGroupClusterOps(block, &clusters, device_type);
590   if (failed(result)) return result;
591 
592   for (const auto& cluster_metadata_and_ops : clusters) {
593     const auto& cluster_ops = cluster_metadata_and_ops.getSecond();
594 
595     bool has_replication =
596         cluster_metadata_and_ops.getFirst() != kNoReplicationCluster;
597     auto cluster_metadata =
598         metadata_map.find(cluster_metadata_and_ops.getFirst());
599 
600     // No TPUReplicateMetadata for a `_replication_info` attribute.
601     if (has_replication && cluster_metadata == metadata_map.end()) {
602       block->getParentOp()->emitWarning()
603           << "TPUReplicateMetadata for associated '" << TF::kReplicationInfoAttr
604           << "' attribute '" << cluster_metadata_and_ops.getFirst()
605           << "' is missing";
606       continue;
607     }
608 
609     OpSetVector cluster_successor_ops =
610         CollectClusterSuccessorOps(block, cluster_ops, side_effect_analysis);
611 
612     llvm::SmallVector<Value, 8> results =
613         CollectClusterResults(block, cluster_ops);
614 
615     tf_device::ClusterOp cluster = CreateClusterOp(
616         block, cluster_ops, results, cluster_successor_ops.getArrayRef());
617 
618     if (!has_replication) {
619       SetNoReplicationClusterAttrs(cluster, device_type);
620       continue;
621     }
622     // Determine `num_replicas`.
623     auto num_replicas_attr =
624         cluster_metadata->getSecond().get(kNumReplicasAttr);
625     if (!num_replicas_attr || !num_replicas_attr.isa<mlir::IntegerAttr>())
626       return cluster.emitError()
627              << "requires '" << kNumReplicasAttr << "' int attribute";
628     int num_replicas = num_replicas_attr.cast<mlir::IntegerAttr>().getInt();
629 
630     // Determine `num_cores_per_replica`.
631     int num_cores_per_replica = 1;
632     auto num_cores_per_replica_attr =
633         cluster_metadata->getSecond()
634             .get(kNumCoresPerReplicaAttr)
635             .dyn_cast_or_null<mlir::IntegerAttr>();
636     if (num_cores_per_replica_attr)
637       num_cores_per_replica = num_cores_per_replica_attr.getInt();
638     if (failed(ReplicateCluster(cluster, num_replicas, num_cores_per_replica)))
639       return failure();
640 
641     // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
642     cluster->setAttrs(
643         cluster_metadata->second.getDictionary(cluster.getContext()));
644     // Exclude `num_replicas` as cluster should be replicated if necessary.
645     cluster->removeAttr(kNumReplicasAttr);
646   }
647 
648   return success();
649 }
650 
FormClustersInFunction(func::FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)651 LogicalResult FormClustersInFunction(
652     func::FuncOp func,
653     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
654   if (!llvm::hasSingleElement(func))
655     return func.emitOpError("Expecting a single block function");
656 
657   if (failed(FormClustersInBlock(&func.front(), side_effect_analysis)))
658     return failure();
659 
660   // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
661   auto remove_result = func.walk([&](Operation* op) {
662     if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
663       return WalkResult::advance();
664 
665     // Forward operand to result. When `num_replicas` attribute is 1, no
666     // `tf_device.replicate` is created and replicated (1) operands/results are
667     // untouched.
668     if (op->getNumOperands() == 1 && op->getNumResults() == 1)
669       op->getResult(0).replaceAllUsesWith(op->getOperand(0));
670 
671     // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
672     // `num_replicas` to 1.
673     if (!op->use_empty()) {
674       op->emitOpError() << "is expected to have no uses, but it is operand#"
675                         << op->use_begin()->getOperandNumber() << " of "
676                         << *op->use_begin()->getOwner();
677       return WalkResult::interrupt();
678     }
679 
680     op->erase();
681 
682     return WalkResult::advance();
683   });
684 
685   return failure(remove_result.wasInterrupted());
686 }
687 
runOnOperation()688 void TPUClusterFormationPass::runOnOperation() {
689   // Attributes on tf.Constant aren't reliable: CSE will merge ConstantLike ops
690   // with the same value (but different attributes!) into the same tf.Const
691   // definition, potentially leading to bogus _replication_info attributes. So
692   // we just scrub all tf.Constants of all extra attributes.
693   // TODO(kramm): Remove this once tf.Const's folder is aware of extra
694   // attributes.
695   auto value_str_attr = StringAttr::get(&getContext(), "value");
696   getOperation().walk([&](TF::ConstOp cst) {
697     auto dict = cst->getAttrDictionary();
698     if (dict.size() == 1) {
699       return;  // Optimization. Assume the one attribute is "value".
700     }
701     // Recreate the attributes dictionary to only contain "value".
702     NamedAttrList attributes;
703     attributes.append(NamedAttribute(value_str_attr, cst->getAttr("value")));
704     cst->setAttrs(attributes.getDictionary(&getContext()));
705   });
706 
707   auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
708   for (auto func : getOperation().getOps<func::FuncOp>())
709     if (!func.isExternal() &&
710         failed(FormClustersInFunction(
711             func, side_effect_analysis.GetAnalysisForFunc(func))))
712       return signalPassFailure();
713 }
714 }  // anonymous namespace
715 
CreateTPUClusterFormationPass()716 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
717   return std::make_unique<TPUClusterFormationPass>();
718 }
719 
720 }  // namespace TFTPU
721 }  // namespace mlir
722