/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

namespace tensorflow {
namespace tfrt_compiler {
namespace {

// A special DenseMapInfo that hashes only the operands of an operation, and
// treats two operations as equivalent if their operands are the same.
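// Note that the branch attributes of tf.If are deliberately not part of the
// key: tf.If ops that take the same operands but call different branches are
// exactly the ops this pass groups together and merges.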
struct OpWithSameArgsInfo : llvm::DenseMapInfo<mlir::Operation *> {
  static unsigned getHashValue(const mlir::Operation *const_op) {
    auto *op = const_cast<mlir::Operation *>(const_op);
    return llvm::hash_combine(
        llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
  }

  static bool isEqual(const mlir::Operation *const_lhs,
                      const mlir::Operation *const_rhs) {
    auto *lhs = const_cast<mlir::Operation *>(const_lhs);
    auto *rhs = const_cast<mlir::Operation *>(const_rhs);
    if (lhs == rhs) return true;
    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
        rhs == getTombstoneKey() || rhs == getEmptyKey())
      return false;

    return std::equal(lhs->operand_begin(), lhs->operand_end(),
                      rhs->operand_begin(), rhs->operand_end());
  }
};

// This pass merges non-side-effecting tf.If ops if their operands are the same.
// For example,
//  %r0 = tf.If(%cond, %x) {else = @else_0, then = @then_0}
//  %r1, %r2 = tf.If(%cond, %x) {else = @else_1, then = @then_1}
//
// will be converted to:
//  func private @merge_else(%arg) {
//    %r0 = tf.PartitionedCall(%arg) {f = @else_0}
//    %r1, %r2 = tf.PartitionedCall(%arg) {f = @else_1}
//    return %r0, %r1, %r2
//  }
//  func private @merge_then(%arg) {
//    %r0 = tf.PartitionedCall(%arg) {f = @then_0}
//    %r1, %r2 = tf.PartitionedCall(%arg) {f = @then_1}
//    return %r0, %r1, %r2
//  }
//
//  %r0, %r1, %r2 = tf.If(%cond, %x) {else = @merge_else, then = @merge_then}
//
// Then the inliner pass is run on the module, so the bodies of else_0 and
// else_1 are inlined into the body of merge_else, and the bodies of then_0 and
// then_1 are inlined into the body of merge_then.
//
// Note that the results of the original tf.If ops are concatenated, in order,
// to form the results of the merged tf.If op.
class MergeTfIfOpsPass
    : public mlir::PassWrapper<MergeTfIfOpsPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  llvm::StringRef getArgument() const final { return "tfrt-merge-tf-if-ops"; }
  llvm::StringRef getDescription() const final {
    return "Merge stateless tf.If ops with the same arguments.";
  }

  void runOnOperation() override {
    constexpr int kMaxIter = 10;
    auto module = getOperation();

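    // Iterate merging and inlining until a fixed point is reached (no more
    // tf.If ops can be merged) or kMaxIter iterations have run.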
    bool changed = true;
    for (int i = 0; i < kMaxIter && changed; ++i) {
      changed = false;
      for (auto func_op :
           llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
        changed |= ProcessFunction(func_op, i);
      }

      if (changed) {
        // Run the inliner pass so that the bodies of the original branches are
        // inlined into the newly created merged branch functions, which may
        // expose more merge opportunities for the next iteration.
        mlir::OpPassManager pm(module.getOperationName());
        pm.addPass(mlir::createInlinerPass());
        if (mlir::failed(runPipeline(pm, module))) {
          module.emitWarning(
              absl::StrCat("could not run inliner pass within the "
                           "tfrt-merge-tf-if-ops pass iteration ",
                           i));
          break;
        }
      }
    }
  }

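  // Merges the stateless tf.If ops in `op` that share the same operands.
  // `iteration` is only used to generate unique names for the merged branch
  // functions. Returns true if any tf.If ops were merged.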
  bool ProcessFunction(mlir::func::FuncOp op, int iteration) {
    // Use a hash map to group tf.If ops with the same operands.
    llvm::SmallDenseMap<mlir::Operation *, llvm::SmallVector<mlir::TF::IfOp, 2>,
                        2, OpWithSameArgsInfo>
        if_ops_to_merge;

    for (mlir::Operation &op : op.front()) {
      auto if_op = llvm::dyn_cast<mlir::TF::IfOp>(&op);

      // Skip ops that are not tf.If, and tf.If ops that are side-effecting.
      if (!if_op || !if_op.is_stateless()) continue;

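      // The tf.If op itself serves as the map key; OpWithSameArgsInfo makes
      // ops with identical operand lists hash and compare equal, so they all
      // end up in the same bucket.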
      if_ops_to_merge[if_op].push_back(if_op);
    }

    int id = 0;

    // Set the insertion point to just before the current function, so that the
    // new branch functions are created in the enclosing module.
    mlir::OpBuilder builder(op);

    // Track the tf.If ops that should be removed as they are merged.
    llvm::SmallVector<mlir::TF::IfOp, 4> if_ops_to_remove;

    bool changed = false;
    for (auto &iter : if_ops_to_merge) {
      if (iter.second.size() <= 1) continue;

      // Merge tf.If ops that have the same operands. The merged branches will
      // be given unique names.
      MergeIfOpsWithSameArgs(builder, iter.first->getLoc(),
                             /*branch_prefix=*/
                             absl::StrCat(op.getSymName().str(), "_merged_if_",
                                          iteration, "_", id++),
                             iter.second);

      if_ops_to_remove.append(iter.second.begin(), iter.second.end());
      changed = true;
    }

    // Now that `if_ops_to_merge` (and any other data structure that refers to
    // the merged operations) is no longer used, it is safe to erase these
    // tf.If ops.
    for (auto op : if_ops_to_remove) op->erase();

    return changed;
  }

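  // Replaces `if_ops`, which all take the same operands, with a single merged
  // tf.If op. The merged op's branches call each original op's branch via
  // tf.PartitionedCall, and its results are the original ops' results
  // concatenated in order.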
  void MergeIfOpsWithSameArgs(mlir::OpBuilder &builder, mlir::Location loc,
                              absl::string_view branch_prefix,
                              llvm::MutableArrayRef<mlir::TF::IfOp> if_ops) {
    assert(if_ops.size() > 1);

    // The results of the merged tf.If op are the concatenation of results of
    // the original tf.If ops.
    llvm::SmallVector<mlir::Type, 4> new_result_types;
    for (auto if_op : if_ops) {
      new_result_types.append(if_op->result_type_begin(),
                              if_op->result_type_end());
    }

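    // All ops in `if_ops` take the same operands, so the first op's input
    // types can be used as the argument types of the new branch functions.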
    auto branch_function_type = builder.getFunctionType(
        if_ops.front().input().getTypes(), new_result_types);

    // Create new branches for the merged tf.If op.
    auto then_branch_name = CreateBranchFunction(
        builder, loc, branch_prefix,
        /*branch_suffix=*/"_then", branch_function_type, if_ops,
        [](mlir::TF::IfOp op) { return op.then_branchAttr(); });

    auto else_branch_name = CreateBranchFunction(
        builder, loc, branch_prefix,
        /*branch_suffix=*/"_else", branch_function_type, if_ops,
        [](mlir::TF::IfOp op) { return op.else_branchAttr(); });

    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPoint(if_ops.front());

    // Create the merged tf.If op using the new branches.
    auto new_if_op = builder.create<mlir::TF::IfOp>(
        loc, new_result_types, if_ops.front().cond(), if_ops.front().input(),
        then_branch_name, else_branch_name, /*is_stateless=*/true);

    // Replace the uses of results of the original tf.If ops with the results of
    // the merged tf.If op.
    auto new_result_iter = new_if_op.output().begin();
    for (auto if_op : if_ops) {
      for (auto result : if_op.output()) {
        assert(new_result_iter != new_if_op.output().end());
        result.replaceAllUsesWith(*new_result_iter);
        ++new_result_iter;
      }
    }
  }

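  // Creates a private function named `branch_prefix` + `branch_suffix` whose
  // body calls, for each op in `if_ops`, the branch selected by `get_branch`
  // through a tf.PartitionedCall, and returns the concatenated call results.
  // Returns the symbol name of the newly created function.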
  llvm::StringRef CreateBranchFunction(
      mlir::OpBuilder &builder, mlir::Location loc,
      absl::string_view branch_prefix, absl::string_view branch_suffix,
      mlir::FunctionType branch_function_type,
      llvm::ArrayRef<mlir::TF::IfOp> if_ops,
      llvm::function_ref<mlir::FlatSymbolRefAttr(mlir::TF::IfOp)> get_branch) {
    std::string branch_name = absl::StrCat(branch_prefix, branch_suffix);
    auto branch = builder.create<mlir::func::FuncOp>(loc, branch_name,
                                                     branch_function_type);
    branch.setVisibility(mlir::func::FuncOp::Visibility::Private);

    mlir::OpBuilder::InsertionGuard guard(builder);

    // In the body of the newly created branch function, insert
    // tf.PartitionedCall ops that call the original branches.
    auto *block = branch.addEntryBlock();
    builder.setInsertionPointToStart(block);
    auto empty_string_attr = builder.getStringAttr("");
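    // `empty_string_attr` is used below for the `config`, `config_proto` and
    // `executor_type` attributes of tf.PartitionedCall, which are left empty.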

    llvm::SmallVector<mlir::Value, 4> results;
    results.reserve(branch_function_type.getNumResults());

    for (auto if_op : if_ops) {
      // Create the call op to the original branch. The arguments are simply
      // the arguments of the wrapper function.
      auto call_op = builder.create<mlir::TF::PartitionedCallOp>(
          if_op.getLoc(), if_op.getResultTypes(), block->getArguments(),
          get_branch(if_op), empty_string_attr, empty_string_attr,
          empty_string_attr);

      // The merged results are the concatenation of the results of the calls
      // to the original branches.
      results.append(call_op.output().begin(), call_op.output().end());
    }

    builder.create<mlir::func::ReturnOp>(loc, results);

    return branch.getSymName();
  }

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeTfIfOpsPass)
};

}  // namespace

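// Typical usage (a minimal sketch; the pass-manager setup shown here is
// illustrative and not part of this file):
//   mlir::PassManager pm(&context);
//   pm.addPass(tensorflow::tfrt_compiler::CreateMergeTfIfOpsPass());
//   if (mlir::failed(pm.run(module))) { /* handle the error */ }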
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateMergeTfIfOpsPass() {
  return std::make_unique<MergeTfIfOpsPass>();
}

static mlir::PassRegistration<MergeTfIfOpsPass> register_pass(
    CreateMergeTfIfOpsPass);

}  // namespace tfrt_compiler
}  // namespace tensorflow