1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/DenseSet.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringExtras.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/iterator_range.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "mlir/IR/Attributes.h" // from @llvm-project
36 #include "mlir/IR/Builders.h" // from @llvm-project
37 #include "mlir/IR/MLIRContext.h" // from @llvm-project
38 #include "mlir/IR/Operation.h" // from @llvm-project
39 #include "mlir/IR/Types.h" // from @llvm-project
40 #include "mlir/IR/Value.h" // from @llvm-project
41 #include "mlir/Pass/Pass.h" // from @llvm-project
42 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
43 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
44 #include "mlir/Support/LogicalResult.h" // from @llvm-project
45 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
46 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
48 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
49 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
50 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
51 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
52
53 namespace mlir {
54 namespace TFTPU {
55
56 namespace {
57
58 constexpr llvm::StringRef kDeviceAttr = "device";
59 constexpr llvm::StringRef kNameAttr = "name";
60 constexpr llvm::StringRef kNumCoresPerReplicaAttr = "num_cores_per_replica";
61 constexpr llvm::StringRef kNumReplicasAttr = "num_replicas";
62 constexpr llvm::StringRef kMirroredVariableIndicesAttr =
63 "_mirrored_variable_indices";
64 constexpr llvm::StringRef kNoReplicationCluster = "__no_replication_cluster";
65
66 constexpr llvm::StringRef kBadReplicateInfoAttrMsg =
67 "requires '_replication_info' string attribute";
68
69 // Mapping for `_replication_info` attribute to TPUReplicateMetadata attributes.
70 using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttrList, 8>;
71
72 // A set of operations. We use a `SmallSetVector` in order to have deterministic
73 // traversal order (= insertion order), independent of the pointer keys.
74 using OpSetVector = llvm::SmallSetVector<Operation*, 8>;
75
76 // Mapping for `_replication_info` attribute to ops of a cluster.
77 using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, OpSetVector, 8>;
78
79 struct TPUClusterFormationPass
80 : public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
getDependentDialectsmlir::TFTPU::__anon1e87b7740111::TPUClusterFormationPass81 void getDependentDialects(DialectRegistry& registry) const override {
82 registry.insert<tf_device::TensorFlowDeviceDialect>();
83 }
84
85 void runOnOperation() override;
86 };
87
88 // Creates a mapping from the TPUReplicateMetadata ops `_replication_info`
89 // attribute to its attributes and removes the ops. If multiple
90 // TPUReplicateMetadata ops have the same `_replication_info` attribute, an
91 // error will be returned.
CollectMetadata(Block * block,MetadataMap * metadata_map)92 LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
93 // Just look at top-level operations in the block (not nested ones)
94 for (Operation& op : llvm::make_early_inc_range(*block)) {
95 auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
96 if (!metadata_op) continue;
97
98 NamedAttrList attrs(metadata_op->getAttrDictionary());
99
100 // Missing or bad `_replication_info` attribute.
101 auto replication_info_attr = attrs.get(TF::kReplicationInfoAttr);
102 if (!replication_info_attr)
103 return metadata_op.emitError() << kBadReplicateInfoAttrMsg;
104
105 auto replication_info_attr_str =
106 replication_info_attr.dyn_cast<StringAttr>();
107 if (!replication_info_attr_str ||
108 replication_info_attr_str.getValue().empty())
109 return metadata_op.emitError() << kBadReplicateInfoAttrMsg;
110
111 // Remove `name` attribute.
112 attrs.erase(StringAttr::get(metadata_op.getContext(), kNameAttr));
113
114 auto it = metadata_map->try_emplace(replication_info_attr_str.getValue(),
115 std::move(attrs));
116
117 // There are multiple TPUReplicateMetadata ops with the same
118 // `_replication_info` attribute.
119 if (!it.second) {
120 return metadata_op.emitError()
121 << "multiple TPUReplicateMetadata ops with the same '"
122 << TF::kReplicationInfoAttr << "' attribute '"
123 << replication_info_attr_str.getValue() << "' found";
124 }
125 metadata_op.erase();
126 }
127 return success();
128 }
129
130 // Collects and clusters ops either based on `_replication_info` attribute
131 // (replicated case) or using one single cluster (non-replicated case). Also
132 // sets `device_type` if there is any cluster (note that the device type must be
133 // unique, otherwise we emit an error).
134 // Returns an error in case of invalid compilation or replication attribute(s).
CollectAndGroupClusterOps(Block * block,ClusterMap * clusters,std::string & device_type)135 LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters,
136 std::string& device_type) {
137 bool has_replicated_compiled_op = false;
138 bool has_non_replicated_compiled_op = false;
139 // Use ordered set here to make error message below deterministic.
140 std::set<llvm::StringRef> device_types;
141 for (Operation& op : *block) {
142 LogicalResult result = TF::HasValidCompilationAndReplicationAttributes(op);
143 if (failed(result)) return result;
144
145 // Collect device types which currently must be consistent per block
146 // (checked later).
147 auto device_type_attr =
148 op.getAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr);
149 if (device_type_attr) device_types.insert(device_type_attr);
150
151 if (op.hasAttr(TF::kReplicationInfoAttr)) {
152 // For replicated case, borrow cluster structure from replication info.
153 // Following condition is already checked in
154 // `HasValidCompilationAndReplicationAttributes` above, assert here for
155 // documentation and to avoid breakage when that function is changed.
156 assert(op.hasAttr(TF::kCompileDeviceTypeAttr));
157 has_replicated_compiled_op = true;
158 auto attr = op.getAttrOfType<StringAttr>(TF::kReplicationInfoAttr);
159 auto it = clusters->try_emplace(attr.getValue());
160 it.first->getSecond().insert(&op);
161 } else if (op.hasAttr(TF::kCompileDeviceTypeAttr)) {
162 // For non-replicated case, assume one cluster per block (in line with
163 // Framework behavior).
164 has_non_replicated_compiled_op = true;
165 auto it = clusters->try_emplace(kNoReplicationCluster);
166 it.first->getSecond().insert(&op);
167 }
168 }
169 // Do some checks for unsupported cases.
170 if (has_replicated_compiled_op && has_non_replicated_compiled_op) {
171 return block->getParentOp()->emitError()
172 << "found mixed replicated and non-replicated compiled ops in same "
173 "block which is not supported";
174 }
175 if (device_types.size() > 1) {
176 return block->getParentOp()->emitError()
177 << "found different '" << TF::kCompileDeviceTypeAttr
178 << "' attribute values (" << llvm::join(device_types, ",")
179 << ") in same block which is not supported";
180 }
181 if (!clusters->empty()) {
182 // Note that for size < 1 we shouldn't have any cluster while for size > 1
183 // we should have returned with an error above.
184 assert(device_types.size() == 1);
185 device_type = device_types.begin()->str();
186 }
187 return success();
188 }
189
190 // Returns true iff `op` has a direct control dependency from (`incoming` ==
191 // true) or to (`incoming` == false) any op in `cluster_ops` or
192 // `cluster_dependent_ops`.
hasOpClusterControlDependency(Operation * op,bool incoming,const OpSetVector & cluster_ops,const OpSetVector & cluster_dependent_ops,const TF::SideEffectAnalysis::Info & side_effect_analysis)193 bool hasOpClusterControlDependency(
194 Operation* op, bool incoming, const OpSetVector& cluster_ops,
195 const OpSetVector& cluster_dependent_ops,
196 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
197 auto filter = [&](Operation* other_op) {
198 return cluster_ops.contains(other_op) ||
199 cluster_dependent_ops.contains(other_op);
200 };
201 return incoming ? !side_effect_analysis.DirectControlPredecessors(op, filter)
202 .empty()
203 : !side_effect_analysis.DirectControlSuccessors(op, filter)
204 .empty();
205 }
206
207 // Returns true iff `op` has a direct data dependency from (`incoming` == true
208 // or to (`incoming` == false) any op in `cluster_ops` or
209 // `cluster_dependent_ops`.
hasOpClusterDataDependency(Operation * op,bool incoming,const OpSetVector & cluster_ops,const OpSetVector & cluster_dependent_ops)210 bool hasOpClusterDataDependency(Operation* op, bool incoming,
211 const OpSetVector& cluster_ops,
212 const OpSetVector& cluster_dependent_ops) {
213 auto result = op->walk([&](Operation* inner_op) {
214 ValueRange values = incoming ? ValueRange(inner_op->getOperands())
215 : ValueRange(inner_op->getResults());
216 llvm::SmallVector<Operation*, 4> candidates;
217 for (Value value : values) {
218 if (incoming) {
219 candidates = {value.getDefiningOp()};
220 } else {
221 candidates.assign(value.getUsers().begin(), value.getUsers().end());
222 }
223 for (Operation* candidate_op : candidates) {
224 if (cluster_ops.contains(candidate_op) ||
225 cluster_dependent_ops.contains(candidate_op)) {
226 return WalkResult::interrupt();
227 }
228 }
229 }
230 return WalkResult::advance();
231 });
232 return result.wasInterrupted();
233 }
234
235 // Collects ops that need to be moved behind the cluster due to data or control
236 // dependencies.
CollectClusterSuccessorOps(Block * block,const OpSetVector & cluster_ops,const TF::SideEffectAnalysis::Info & side_effect_analysis)237 llvm::SmallSetVector<Operation*, 8> CollectClusterSuccessorOps(
238 Block* block, const OpSetVector& cluster_ops,
239 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
240 OpSetVector cluster_predecessor_ops;
241 OpSetVector cluster_successor_ops;
242
243 // Collect non-cluster ops that have a dependency to the cluster. For this
244 // traverse all ops from last to first cluster op and keep track of in-between
245 // non-cluster ops that have some outgoing (transitive) dependency to some
246 // cluster op (`cluster_predecessor_ops`).
247 auto rfront = Block::reverse_iterator(cluster_ops.front());
248 auto rback = Block::reverse_iterator(cluster_ops.back());
249 for (Operation& op : llvm::make_range(rback, rfront)) {
250 if (cluster_ops.contains(&op)) continue;
251 bool has_dependency_to_cluster =
252 hasOpClusterDataDependency(&op, /*incoming=*/false, cluster_ops,
253 cluster_predecessor_ops) ||
254 hasOpClusterControlDependency(&op, /*incoming=*/false, cluster_ops,
255 cluster_predecessor_ops,
256 side_effect_analysis);
257 if (has_dependency_to_cluster) cluster_predecessor_ops.insert(&op);
258 }
259 // Collect non-cluster ops that have a dependency from the cluster. For this
260 // traverse all ops from first to last cluster op and keep track of in-between
261 // non-cluster ops that have some incoming (transitive) dependency from some
262 // cluster op (`cluster_successor_ops`).
263 auto front = Block::iterator(cluster_ops.front());
264 auto back = Block::iterator(cluster_ops.back());
265 for (Operation& op : llvm::make_range(front, back)) {
266 if (cluster_ops.contains(&op)) continue;
267 bool has_dependency_from_cluster =
268 hasOpClusterDataDependency(&op, /*incoming=*/true, cluster_ops,
269 cluster_successor_ops) ||
270 hasOpClusterControlDependency(&op, /*incoming=*/true, cluster_ops,
271 cluster_successor_ops,
272 side_effect_analysis);
273 if (has_dependency_from_cluster) {
274 if (cluster_predecessor_ops.contains(&op)) {
275 // Op has a dependency from and to the cluster which is invalid. Instead
276 // of erroring out we don't add the op to `cluster_successor_ops` which
277 // is in line with previous behavior when certain control dependencies
278 // were not considered.
279 // TODO(b/216706460) Establish some contract here: Should we expect only
280 // valid clusters, or should we split clusters accordingly? The latter
281 // might have runtime impact for existing models.
282 // We should make this message an error once there is such a contract
283 // and once existing cases have been fixed.
284 op.emitWarning()
285 << "op has cyclic dependency with a compilation cluster";
286 } else {
287 cluster_successor_ops.insert(&op);
288 }
289 }
290 }
291 return cluster_successor_ops;
292 }
293
294 // Collects results and associated types of the cluster that are used outside of
295 // the cluster. These results and types are used to create the clusters
296 // `tf_device.cluster` and associated terminator. Results that have no uses
297 // outside of the cluster (i.e. results of ops in the cluster are only consumed
298 // by other ops in the cluster) are pruned.
CollectClusterResults(Block * block,const OpSetVector & cluster_ops)299 llvm::SmallVector<Value, 8> CollectClusterResults(
300 Block* block, const OpSetVector& cluster_ops) {
301 llvm::SmallVector<Value, 8> results;
302
303 for (Operation* op : cluster_ops) {
304 for (Value result : op->getResults()) {
305 for (Operation* user : result.getUsers()) {
306 // Check if user is not an op in the cluster.
307 if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
308 results.push_back(result);
309 break;
310 }
311 }
312 }
313 }
314
315 return results;
316 }
317
318 // Creates a `tf_device.cluster` to wrap cluster ops.
CreateClusterOp(Block * block,const OpSetVector & cluster_ops,llvm::ArrayRef<Value> results,llvm::ArrayRef<Operation * > cluster_successor_ops)319 tf_device::ClusterOp CreateClusterOp(
320 Block* block, const OpSetVector& cluster_ops, llvm::ArrayRef<Value> results,
321 llvm::ArrayRef<Operation*> cluster_successor_ops) {
322 // `tf_device.cluster` will be placed at where the last op of the cluster is.
323 Operation* last_cluster_op = cluster_ops.back();
324 OpBuilder builder(last_cluster_op);
325
326 llvm::SmallVector<Type, 8> result_types;
327 for (Value result : results) result_types.push_back(result.getType());
328 auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
329 result_types);
330
331 Block* body = new Block;
332 cluster.body().push_back(body);
333
334 // Move cluster ops to the cluster body. Also remove `_replication_info` and
335 // `device` attribute from ops in the cluster when that information is
336 // redundant will the `tf_device.cluster`. Do this for all ops including
337 // nested ops.
338 for (Operation* cluster_op : cluster_ops) {
339 cluster_op->moveBefore(body, body->end());
340 cluster_op->walk([&](Operation* inner_op) {
341 inner_op->removeAttr(TF::kReplicationInfoAttr);
342 inner_op->removeAttr(TF::kCompileDeviceTypeAttr);
343
344 if (auto attr = inner_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
345 // Preserve device attribute if the op is placed on a replicated core
346 // device. Device attribute is used to infer the appropriate sharding
347 // within TPUs for this op.
348 // TODO(b/183598857): Use explicit sharding ops from the front-end.
349 // For example, dequeue ops generated by
350 // tensorflow/python/tpu/tpu_feed.py
351 if (!tensorflow::IsTPUReplicatedCore(attr.getValue())) {
352 inner_op->removeAttr(kDeviceAttr);
353 }
354 }
355 });
356 }
357
358 // Add terminator.
359 builder.setInsertionPointToEnd(body);
360 builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
361
362 // Replaces uses of cluster ops results outside of cluster with the associated
363 // `tf_device.cluster` results.
364 for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
365 Value old_ret = std::get<0>(ret_vals);
366 Value new_ret = std::get<1>(ret_vals);
367 for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
368 Operation* user = use.getOwner();
369 if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
370 }
371 }
372
373 // Move ops that depend on something in the cluster behind the cluster.
374 Operation* op_after_cluster = cluster.getOperation()->getNextNode();
375 for (Operation* op : cluster_successor_ops) op->moveBefore(op_after_cluster);
376 return cluster;
377 }
378
379 // Creates a `tf_device.replicate` to represent replication for the cluster, if
380 // necessary.
ReplicateCluster(tf_device::ClusterOp cluster,int num_replicas,int num_cores_per_replica)381 LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
382 int num_cores_per_replica) {
383 // No need to replicate.
384 if (num_replicas == 1) return success();
385
386 if (num_replicas < 1)
387 return cluster.emitError() << "requires '" << kNumReplicasAttr
388 << "' int attribute to be at least 1";
389
390 LogicalResult status = success();
391 // Collect all used TPUReplicatedInput ops.
392 llvm::SmallVector<Operation*, 8> replicated_input_ops;
393 mlir::visitUsedValuesDefinedAbove(
394 cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
395 Operation* def = operand->get().getDefiningOp();
396 if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(def))
397 replicated_input_ops.push_back(def);
398 // When model parallelism is used in conjunction with data parallelism
399 // for resource inputs, we need to collect the per replica resource
400 // inputs from input to `tf.TPUPartitionedInput` ops.
401 if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
402 if (pi->getNumOperands() != num_cores_per_replica)
403 status = pi.emitOpError()
404 << "requires " << num_cores_per_replica
405 << " operands but found " << pi->getNumOperands();
406 for (auto operand : pi.inputs()) {
407 if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(
408 operand.getDefiningOp()))
409 replicated_input_ops.push_back(operand.getDefiningOp());
410 }
411 }
412 });
413
414 if (failed(status)) return failure();
415
416 // Indices of the replicate op's arguments that are mirrored variables.
417 llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
418
419 // Check if number of operands of each used TPUReplicatedInput op matches
420 // `num_replicas` or 1. Collect all their operands and associated type for
421 // creating the replicate op.
422 llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
423 llvm::SmallVector<Value, 8> packed_inputs;
424 llvm::SmallVector<Operation*, 8> replicated_ops;
425 llvm::SmallVector<Operation*, 8> packed_ops;
426 for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
427 auto input = pos_and_input.value();
428 bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
429 const int num_operands = input->getNumOperands();
430 int num_inputs = is_packed ? 1 : num_replicas;
431 if (num_operands != num_inputs)
432 return input->emitOpError() << "requires " << num_inputs << " operands";
433 if (is_packed) {
434 packed_inputs.push_back(input->getOperand(0));
435 packed_ops.push_back(input);
436 } else {
437 replicated_inputs.push_back(
438 {input->getOperands(), input->getOperand(0).getType()});
439 replicated_ops.push_back(input);
440 }
441 }
442
443 // Create `ordered_tpu_replicate_inputs` which constains the final ordered
444 // replicate inputs. All packed arguments are moved to the end of the arg
445 // list.
446 llvm::SmallVector<Operation*, 8> ordered_tpu_replicate_inputs =
447 replicated_ops;
448 ordered_tpu_replicate_inputs.append(packed_ops.begin(), packed_ops.end());
449
450 // Assign `mirrored_variable_indices` based on the ordered replicated inputs.
451 for (const auto& pos_and_input :
452 llvm::enumerate(ordered_tpu_replicate_inputs)) {
453 auto tpu_replicated_input =
454 llvm::cast<TF::TPUReplicatedInputOp>(pos_and_input.value());
455 if (tpu_replicated_input.is_mirrored_variable()) {
456 mirrored_variable_indices.push_back(pos_and_input.index());
457 }
458 }
459
460 // Create replicate op.
461 OpBuilder builder(cluster);
462 auto replicate_op = builder.create<tf_device::ReplicateOp>(
463 cluster.getLoc(), num_replicas,
464 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
465 replicated_inputs, packed_inputs, cluster.getResultTypes());
466
467 if (!mirrored_variable_indices.empty())
468 replicate_op->setAttr(kMirroredVariableIndicesAttr,
469 builder.getI64ArrayAttr(mirrored_variable_indices));
470
471 // Replace replicated cluster results with replicate op results.
472 for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
473 Value result = result_and_idx.value();
474 int idx = result_and_idx.index();
475 auto replicate_outputs = llvm::make_range(
476 std::next(replicate_op.result_begin(), idx * num_replicas),
477 std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
478
479 for (auto& use : llvm::make_early_inc_range(result.getUses())) {
480 Operation* def = use.getOwner();
481 if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
482 // If user is not a `tf.TPUReplicatedOutput`, simply forward the first
483 // replica output. Certain Graphs under V1 create `tf.Identity` users of
484 // replicated ops to pin the TPU computation for execution.
485 use.set(*replicate_outputs.begin());
486 continue;
487 }
488
489 const int def_num_results = def->getNumResults();
490 if (def_num_results != num_replicas)
491 return def->emitOpError() << "requires " << num_replicas << " results";
492
493 def->replaceAllUsesWith(replicate_outputs);
494 }
495 }
496
497 // Collect all `tf.TPUPartitionedInput` ops to be moved inside the
498 // `tf_device.replicate` later.
499 llvm::SmallSet<Operation*, 4> partitioned_inputs;
500 for (auto input_and_block_arg :
501 llvm::zip(ordered_tpu_replicate_inputs,
502 replicate_op.GetBody().getArguments())) {
503 Operation* input = std::get<0>(input_and_block_arg);
504 Value block_arg = std::get<1>(input_and_block_arg);
505 mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
506 cluster.body());
507 // Update replicated input use in tf.TPUPartitionedInput op.
508 for (auto& use : input->getUses()) {
509 auto pi = llvm::dyn_cast<TF::TPUPartitionedInputOp>(use.getOwner());
510 if (pi) {
511 pi.setOperand(use.getOperandNumber(), block_arg);
512 partitioned_inputs.insert(pi.getOperation());
513 }
514 }
515 }
516
517 // Create terminator for replicate op and move `tf_device.cluster` and
518 // `tf.TPUPartitionedInput`(s) into replicate body.
519 builder.setInsertionPointToEnd(&replicate_op.GetBody());
520 auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
521 cluster.getResults());
522 for (auto pi : partitioned_inputs) pi->moveBefore(return_op);
523
524 cluster.getOperation()->moveBefore(return_op);
525
526 return success();
527 }
528
SetNoReplicationClusterAttrs(tf_device::ClusterOp cluster,llvm::StringRef device_type)529 void SetNoReplicationClusterAttrs(tf_device::ClusterOp cluster,
530 llvm::StringRef device_type) {
531 OpBuilder builder(cluster);
532 cluster->setAttr(TF::kReplicationInfoAttr,
533 builder.getStringAttr(kNoReplicationCluster));
534 cluster->setAttr(TF::kCompileDeviceTypeAttr,
535 builder.getStringAttr(device_type));
536
537 // TODO(b/229992058) Propagate `allow_soft_placement` (and other attributes?)
538 // instead of hard-coding.
539 cluster->setAttr("allow_soft_placement", builder.getBoolAttr(true));
540 cluster->setAttr("topology", builder.getStringAttr(""));
541 cluster->setAttr("num_cores_per_replica",
542 builder.getIntegerAttr(builder.getI32Type(), 1));
543 cluster->setAttr("device_assignment", builder.getArrayAttr({}));
544 cluster->setAttr("use_spmd_for_xla_partitioning", builder.getBoolAttr(false));
545 cluster->setAttr("step_marker_location", builder.getStringAttr(""));
546 }
547
548 // Forms compilation clusters in `block`. If the block contains a
549 // `TPUReplicateMetadata` op, then we form clusters according to
550 // `_replication_info` values (ops with same value go to same cluster).
551 // Otherwise, in the non-replicated case, we build one compilation cluster per
552 // block.
553 //
554 // We do this in following steps:
555 // 1. Find `TPUReplicateMetadata` op in `block` (might not exist).
556 // 2. Collect and group cluster ops (either based on `_replication_info`
557 // attributes or forming one single cluster).
558 // 3. Find external uses of cluster ops.
559 // 4. Create `tf_device.cluster` with results consisting of the external uses
560 // of cluster ops determined at 3.
561 // 5. Move cluster ops to `tf_device.cluster` body.
562 // 6. Replace external uses of cluster ops uses with `tf_device.cluster`
563 // results.
564 // 7. Move users from 2 to after the `tf_device.cluster`.
565 // 8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
566 // attribute `num_replicas` is greater than 1.
567 // 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
FormClustersInBlock(Block * block,const TF::SideEffectAnalysis::Info & side_effect_analysis)568 LogicalResult FormClustersInBlock(
569 Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
570 MetadataMap metadata_map;
571 LogicalResult result = CollectMetadata(block, &metadata_map);
572 if (failed(result)) return result;
573
574 // If there is no TPUReplicateMetadata op in this block, process blocks in
575 // regions attached to the op's in the block.
576 if (metadata_map.empty()) {
577 for (Operation& op : *block) {
578 for (Region& region : op.getRegions()) {
579 if (!llvm::hasSingleElement(region))
580 return op.emitOpError("Expected single block region");
581 if (failed(FormClustersInBlock(®ion.front(), side_effect_analysis)))
582 return failure();
583 }
584 }
585 }
586
587 ClusterMap clusters;
588 std::string device_type;
589 result = CollectAndGroupClusterOps(block, &clusters, device_type);
590 if (failed(result)) return result;
591
592 for (const auto& cluster_metadata_and_ops : clusters) {
593 const auto& cluster_ops = cluster_metadata_and_ops.getSecond();
594
595 bool has_replication =
596 cluster_metadata_and_ops.getFirst() != kNoReplicationCluster;
597 auto cluster_metadata =
598 metadata_map.find(cluster_metadata_and_ops.getFirst());
599
600 // No TPUReplicateMetadata for a `_replication_info` attribute.
601 if (has_replication && cluster_metadata == metadata_map.end()) {
602 block->getParentOp()->emitWarning()
603 << "TPUReplicateMetadata for associated '" << TF::kReplicationInfoAttr
604 << "' attribute '" << cluster_metadata_and_ops.getFirst()
605 << "' is missing";
606 continue;
607 }
608
609 OpSetVector cluster_successor_ops =
610 CollectClusterSuccessorOps(block, cluster_ops, side_effect_analysis);
611
612 llvm::SmallVector<Value, 8> results =
613 CollectClusterResults(block, cluster_ops);
614
615 tf_device::ClusterOp cluster = CreateClusterOp(
616 block, cluster_ops, results, cluster_successor_ops.getArrayRef());
617
618 if (!has_replication) {
619 SetNoReplicationClusterAttrs(cluster, device_type);
620 continue;
621 }
622 // Determine `num_replicas`.
623 auto num_replicas_attr =
624 cluster_metadata->getSecond().get(kNumReplicasAttr);
625 if (!num_replicas_attr || !num_replicas_attr.isa<mlir::IntegerAttr>())
626 return cluster.emitError()
627 << "requires '" << kNumReplicasAttr << "' int attribute";
628 int num_replicas = num_replicas_attr.cast<mlir::IntegerAttr>().getInt();
629
630 // Determine `num_cores_per_replica`.
631 int num_cores_per_replica = 1;
632 auto num_cores_per_replica_attr =
633 cluster_metadata->getSecond()
634 .get(kNumCoresPerReplicaAttr)
635 .dyn_cast_or_null<mlir::IntegerAttr>();
636 if (num_cores_per_replica_attr)
637 num_cores_per_replica = num_cores_per_replica_attr.getInt();
638 if (failed(ReplicateCluster(cluster, num_replicas, num_cores_per_replica)))
639 return failure();
640
641 // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
642 cluster->setAttrs(
643 cluster_metadata->second.getDictionary(cluster.getContext()));
644 // Exclude `num_replicas` as cluster should be replicated if necessary.
645 cluster->removeAttr(kNumReplicasAttr);
646 }
647
648 return success();
649 }
650
FormClustersInFunction(func::FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)651 LogicalResult FormClustersInFunction(
652 func::FuncOp func,
653 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
654 if (!llvm::hasSingleElement(func))
655 return func.emitOpError("Expecting a single block function");
656
657 if (failed(FormClustersInBlock(&func.front(), side_effect_analysis)))
658 return failure();
659
660 // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
661 auto remove_result = func.walk([&](Operation* op) {
662 if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
663 return WalkResult::advance();
664
665 // Forward operand to result. When `num_replicas` attribute is 1, no
666 // `tf_device.replicate` is created and replicated (1) operands/results are
667 // untouched.
668 if (op->getNumOperands() == 1 && op->getNumResults() == 1)
669 op->getResult(0).replaceAllUsesWith(op->getOperand(0));
670
671 // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
672 // `num_replicas` to 1.
673 if (!op->use_empty()) {
674 op->emitOpError() << "is expected to have no uses, but it is operand#"
675 << op->use_begin()->getOperandNumber() << " of "
676 << *op->use_begin()->getOwner();
677 return WalkResult::interrupt();
678 }
679
680 op->erase();
681
682 return WalkResult::advance();
683 });
684
685 return failure(remove_result.wasInterrupted());
686 }
687
runOnOperation()688 void TPUClusterFormationPass::runOnOperation() {
689 // Attributes on tf.Constant aren't reliable: CSE will merge ConstantLike ops
690 // with the same value (but different attributes!) into the same tf.Const
691 // definition, potentially leading to bogus _replication_info attributes. So
692 // we just scrub all tf.Constants of all extra attributes.
693 // TODO(kramm): Remove this once tf.Const's folder is aware of extra
694 // attributes.
695 auto value_str_attr = StringAttr::get(&getContext(), "value");
696 getOperation().walk([&](TF::ConstOp cst) {
697 auto dict = cst->getAttrDictionary();
698 if (dict.size() == 1) {
699 return; // Optimization. Assume the one attribute is "value".
700 }
701 // Recreate the attributes dictionary to only contain "value".
702 NamedAttrList attributes;
703 attributes.append(NamedAttribute(value_str_attr, cst->getAttr("value")));
704 cst->setAttrs(attributes.getDictionary(&getContext()));
705 });
706
707 auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
708 for (auto func : getOperation().getOps<func::FuncOp>())
709 if (!func.isExternal() &&
710 failed(FormClustersInFunction(
711 func, side_effect_analysis.GetAnalysisForFunc(func))))
712 return signalPassFailure();
713 }
714 } // anonymous namespace
715
CreateTPUClusterFormationPass()716 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
717 return std::make_unique<TPUClusterFormationPass>();
718 }
719
720 } // namespace TFTPU
721 } // namespace mlir
722