1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
17 #include "mlir/Pass/PassManager.h" // from @llvm-project
18 #include "mlir/Transforms/Passes.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
21
22 namespace tensorflow {
23 namespace tfrt_compiler {
24 namespace {
25
26 // A special DenseMapInfo that hashes only operands of a operation, and treats
27 // two operations equivalent if their operands are the same.
28 struct OpWithSameArgsInfo : llvm::DenseMapInfo<mlir::Operation *> {
getHashValuetensorflow::tfrt_compiler::__anon376b11610111::OpWithSameArgsInfo29 static unsigned getHashValue(const mlir::Operation *const_op) {
30 auto *op = const_cast<mlir::Operation *>(const_op);
31 return llvm::hash_combine(
32 llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
33 }
34
isEqualtensorflow::tfrt_compiler::__anon376b11610111::OpWithSameArgsInfo35 static bool isEqual(const mlir::Operation *const_lhs,
36 const mlir::Operation *const_rhs) {
37 auto *lhs = const_cast<mlir::Operation *>(const_lhs);
38 auto *rhs = const_cast<mlir::Operation *>(const_rhs);
39 if (lhs == rhs) return true;
40 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
41 rhs == getTombstoneKey() || rhs == getEmptyKey())
42 return false;
43
44 return std::equal(lhs->operand_begin(), lhs->operand_end(),
45 rhs->operand_begin(), rhs->operand_end());
46 }
47 };
48
49 // This pass merges non-side-effecting tf.If ops if their operands are the same.
50 // For example,
51 // %r0 = tf.If(%cond, %x) {else = @else_0, then = @then_0}
52 // %r1, %r2 = tf.If(%cond, %x) {else = @else_1, then = @then_1}
53 //
54 // will be converted to:
55 // func private @merge_else(%arg) {
56 // %r0 = tf.PartitionedCall(%arg) {f = @else_0}
57 // %r1, %r2 = tf.PartitionedCall(%arg) {f = @else_1}
58 // return %r0, %r1, %r2
59 // }
60 // func private @merge_then(%arg) {
61 // %r0 = tf.PartitionedCall(%arg) {f = @then_0}
62 // %r1, %r2 = tf.PartitionedCall(%arg) {f = @then_1}
63 // return %r0, %r1, %r2
64 // }
65 //
66 // %r0, %r1, %r2 = tf.If(%cond, %arg) {else = @merge_else, then = @merge_then}
67 //
68 // Then the inliner pass is run on the module, so the bodies of else_0 and
69 // else_1 are inlined into the body of merge_else, and the bodies of then_0 and
70 // then_1 are inlined into the body of merge_then.
71 //
72 // Note that the results will be concatenated.
73 class MergeTfIfOpsPass
74 : public mlir::PassWrapper<MergeTfIfOpsPass,
75 mlir::OperationPass<mlir::ModuleOp>> {
getArgument() const76 llvm::StringRef getArgument() const final { return "tfrt-merge-tf-if-ops"; }
getDescription() const77 llvm::StringRef getDescription() const final {
78 return "Merge stateless tf.If ops with the same arguments.";
79 }
80
runOnOperation()81 void runOnOperation() override {
82 constexpr int kMaxIter = 10;
83 auto module = getOperation();
84
85 bool changed = true;
86 for (int i = 0; i < kMaxIter && changed; ++i) {
87 changed = false;
88 for (auto func_op :
89 llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
90 changed |= ProcessFunction(func_op, i);
91 }
92
93 if (changed) {
94 // Run inliner pass to expose more merge opportunities among the
95 // then-branch functions and the else-branch functions that are now
96 // respectively merged, for the next iteration.
97 mlir::OpPassManager pm(module.getOperationName());
98 pm.addPass(mlir::createInlinerPass());
99 if (mlir::failed(runPipeline(pm, module))) {
100 module.emitWarning(
101 absl::StrCat("could not run inliner pass within the "
102 "tfrt-merge-tf-if-ops pass iteration ",
103 i));
104 break;
105 }
106 }
107 }
108 }
109
ProcessFunction(mlir::func::FuncOp op,int iteration)110 bool ProcessFunction(mlir::func::FuncOp op, int iteration) {
111 // Use a hash map to group tf.If ops with the same operands.
112 llvm::SmallDenseMap<mlir::Operation *, llvm::SmallVector<mlir::TF::IfOp, 2>,
113 2, OpWithSameArgsInfo>
114 if_ops_to_merge;
115
116 for (mlir::Operation &op : op.front()) {
117 auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);
118
119 // Skip non tf.If ops and tf.If ops that are side-effecting.
120 if (!if_op || !if_op.is_stateless()) continue;
121
122 if_ops_to_merge[if_op].push_back(if_op);
123 }
124
125 int id = 0;
126
127 // Set the insertion point to the current function, as we will insert new
128 // functions here.
129 mlir::OpBuilder builder(op);
130
131 // Track the tf.If ops that should be removed as they are merged.
132 llvm::SmallVector<mlir::TF::IfOp, 4> if_ops_to_remove;
133
134 bool changed = false;
135 for (auto &iter : if_ops_to_merge) {
136 if (iter.second.size() <= 1) continue;
137
138 // Merge tf.If ops that have the same operands. The merged branches will
139 // be given unique names.
140 MergeIfOpsWithSameArgs(builder, iter.first->getLoc(),
141 /*branch_prefix=*/
142 absl::StrCat(op.getSymName().str(), "_merged_if_",
143 iteration, "_", id++),
144 iter.second);
145
146 if_ops_to_remove.append(iter.second.begin(), iter.second.end());
147 changed = true;
148 }
149
150 // Now that we are no longer using `if_ops_to_merge` or any other data
151 // structures that uses the operations that will be removed, we can now
152 // erase these if ops.
153 for (auto op : if_ops_to_remove) op->erase();
154
155 return changed;
156 }
157
MergeIfOpsWithSameArgs(mlir::OpBuilder & builder,mlir::Location loc,absl::string_view branch_prefix,llvm::MutableArrayRef<mlir::TF::IfOp> if_ops)158 void MergeIfOpsWithSameArgs(mlir::OpBuilder &builder, mlir::Location loc,
159 absl::string_view branch_prefix,
160 llvm::MutableArrayRef<mlir::TF::IfOp> if_ops) {
161 assert(if_ops.size() > 1);
162
163 // The results of the merged tf.If op are the concatenation of results of
164 // the original tf.If ops.
165 llvm::SmallVector<mlir::Type, 4> new_result_types;
166 for (auto if_op : if_ops) {
167 new_result_types.append(if_op->result_type_begin(),
168 if_op->result_type_end());
169 }
170
171 auto branch_function_type = builder.getFunctionType(
172 if_ops.front().input().getTypes(), new_result_types);
173
174 // Create new branches for the merged tf.If op.
175 auto then_branch_name = CreateBranchFunction(
176 builder, loc, branch_prefix,
177 /*branch_suffix=*/"_then", branch_function_type, if_ops,
178 [](mlir::TF::IfOp op) { return op.then_branchAttr(); });
179
180 auto else_branch_name = CreateBranchFunction(
181 builder, loc, branch_prefix,
182 /*branch_suffix=*/"_else", branch_function_type, if_ops,
183 [](mlir::TF::IfOp op) { return op.else_branchAttr(); });
184
185 mlir::OpBuilder::InsertionGuard guard(builder);
186 builder.setInsertionPoint(if_ops.front());
187
188 // Create the merged tf.If op using the new branches.
189 auto new_if_op = builder.create<mlir::TF::IfOp>(
190 loc, new_result_types, if_ops.front().cond(), if_ops.front().input(),
191 then_branch_name, else_branch_name, /*is_stateless=*/true);
192
193 // Replace the uses of results of the original tf.If ops with the results of
194 // the merged tf.If op.
195 auto new_result_iter = new_if_op.output().begin();
196 for (auto if_op : if_ops) {
197 for (auto result : if_op.output()) {
198 assert(new_result_iter != new_if_op.output().end());
199 result.replaceAllUsesWith(*new_result_iter);
200 ++new_result_iter;
201 }
202 }
203 }
204
CreateBranchFunction(mlir::OpBuilder & builder,mlir::Location loc,absl::string_view branch_prefix,absl::string_view branch_suffix,mlir::FunctionType branch_function_type,llvm::ArrayRef<mlir::TF::IfOp> if_ops,llvm::function_ref<mlir::FlatSymbolRefAttr (mlir::TF::IfOp)> get_branch)205 llvm::StringRef CreateBranchFunction(
206 mlir::OpBuilder &builder, mlir::Location loc,
207 absl::string_view branch_prefix, absl::string_view branch_suffix,
208 mlir::FunctionType branch_function_type,
209 llvm::ArrayRef<mlir::TF::IfOp> if_ops,
210 llvm::function_ref<mlir::FlatSymbolRefAttr(mlir::TF::IfOp)> get_branch) {
211 std::string branch_name = absl::StrCat(branch_prefix, branch_suffix);
212 auto branch = builder.create<mlir::func::FuncOp>(loc, branch_name,
213 branch_function_type);
214 branch.setVisibility(mlir::func::FuncOp::Visibility::Private);
215
216 mlir::OpBuilder::InsertionGuard guard(builder);
217
218 // In the body of newly created branch function, we insert
219 // tf.PartitionedCall ops to call the original branches.
220 auto *block = branch.addEntryBlock();
221 builder.setInsertionPointToStart(block);
222 auto empty_string_attr = builder.getStringAttr("");
223
224 llvm::SmallVector<mlir::Value, 4> results;
225 results.reserve(branch_function_type.getNumResults());
226
227 for (auto if_op : if_ops) {
228 // Create the call op to the original branch. The arguments are simply
229 // the arguments from the wrapper function.
230 auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
231 if_op.getLoc(), if_op.getResultTypes(), block->getArguments(),
232 get_branch(if_op), empty_string_attr, empty_string_attr,
233 empty_string_attr);
234
235 // The results are the concatenation of the original branches.
236 results.append(call_op.output().begin(), call_op.output().end());
237 }
238
239 builder.create<mlir::func::ReturnOp>(loc, results);
240
241 return branch.getSymName();
242 }
243
244 public:
245 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeTfIfOpsPass)
246 };
247
248 } // namespace
249
CreateMergeTfIfOpsPass()250 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateMergeTfIfOpsPass() {
251 return std::make_unique<MergeTfIfOpsPass>();
252 }
253
254 static mlir::PassRegistration<MergeTfIfOpsPass> register_pass(
255 CreateMergeTfIfOpsPass);
256
257 } // namespace tfrt_compiler
258 } // namespace tensorflow
259