1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SetVector.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "mlir/IR/Attributes.h" // from @llvm-project
19 #include "mlir/IR/Block.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/Operation.h" // from @llvm-project
22 #include "mlir/IR/Value.h" // from @llvm-project
23 #include "mlir/IR/Visitors.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30
31 namespace mlir {
32 namespace TFTPU {
33
34 namespace {
35
36 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
37
HasOutsideCompilationAttribute(Operation * op)38 bool HasOutsideCompilationAttribute(Operation* op) {
39 return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
40 }
41
42 // Finds op that created a given value. If the value is a BlockArgument, this
43 // returns the owner of the Block.
GetOpOfValue(Value value)44 Operation* GetOpOfValue(Value value) {
45 if (auto block_arg = value.dyn_cast<BlockArgument>())
46 return block_arg.getOwner()->getParentOp();
47
48 return value.getDefiningOp();
49 }
50
51 // TODO(b/158596585): Replace this with a cost model analysis.
IsTrivialUnaryOperation(Operation * op)52 bool IsTrivialUnaryOperation(Operation* op) {
53 return llvm::isa<TF::CastOp, TF::IdentityOp>(op);
54 }
55
56 // Adds outside compilation attributes to unary ops such as Identity/Cast ops
57 // at the head of TPU computation that is used only by other outside compiled
58 // ops. Identity ops and Cast ops is commonly added to the start of TPU
59 // computation. Adding/expanding outside compilation attributes to these ops
60 // will ensure that head outside compiled ops are correctly located and moved to
61 // host.
62 // TODO(b/158691733): Also handle ops inside function calls/control flows.
ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,OpBuilder * builder)63 void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,
64 OpBuilder* builder) {
65 Region* cluster_region = &cluster.body();
66 llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
67
68 // Traverse the graph in topological order to find all outside compiled ops
69 // at head of TPU computation or unary ops that are only used by other outside
70 // compiled ops.
71 auto cluster_ops = cluster.GetBody().without_terminator();
72 for (Operation& cluster_op : cluster_ops) {
73 if (IsTrivialUnaryOperation(&cluster_op) ||
74 HasOutsideCompilationAttribute(&cluster_op)) {
75 auto walk_result = cluster_op.walk([&](Operation* op) {
76 for (Value operand : op->getOperands()) {
77 Operation* operand_op = GetOpOfValue(operand);
78 if (head_outside_compiled_ops.count(operand_op)) continue;
79
80 if (operand_op->getParentRegion() == cluster_region)
81 return WalkResult::interrupt();
82 }
83 return WalkResult::advance();
84 });
85
86 if (!walk_result.wasInterrupted())
87 head_outside_compiled_ops.insert(&cluster_op);
88 }
89 }
90
91 for (auto head_outside_compiled_op :
92 llvm::reverse(head_outside_compiled_ops)) {
93 auto users = head_outside_compiled_op->getUsers();
94 if (users.empty() ||
95 HasOutsideCompilationAttribute(head_outside_compiled_op))
96 continue;
97
98 bool should_expand_op_to_host_computation = true;
99 for (auto consumer_op : users) {
100 if (should_expand_op_to_host_computation &&
101 !HasOutsideCompilationAttribute(consumer_op)) {
102 should_expand_op_to_host_computation = false;
103 continue;
104 }
105 }
106
107 if (should_expand_op_to_host_computation)
108 head_outside_compiled_op->setAttr(kXlaOutsideCompilationAttr,
109 builder->getStringAttr(""));
110 }
111 }
112
113 struct TPUHostComputationExpansionPass
114 : public TF::TPUHostComputationExpansionPassBase<
115 TPUHostComputationExpansionPass> {
116 void runOnOperation() override;
117 };
118
runOnOperation()119 void TPUHostComputationExpansionPass::runOnOperation() {
120 OpBuilder builder(&getContext());
121 getOperation().walk([&](tf_device::ClusterOp cluster) {
122 ExpandHeadOutsideCompiledOps(cluster, &builder);
123 });
124 }
125
126 } // anonymous namespace
127
128 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUHostComputationExpansionPass()129 CreateTPUHostComputationExpansionPass() {
130 return std::make_unique<TPUHostComputationExpansionPass>();
131 }
132
133 } // namespace TFTPU
134 } // namespace mlir
135