1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/iterator_range.h"
23 #include "llvm/Support/Casting.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/Diagnostics.h" // from @llvm-project
28 #include "mlir/IR/Operation.h" // from @llvm-project
29 #include "mlir/IR/OperationSupport.h" // from @llvm-project
30 #include "mlir/IR/SymbolTable.h" // from @llvm-project
31 #include "mlir/Support/LLVM.h" // from @llvm-project
32 #include "mlir/Support/LogicalResult.h" // from @llvm-project
33 #include "mlir/Transforms/Passes.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
36
37 namespace tensorflow {
38 namespace tfrt_compiler {
39 namespace {
40
41 using ::mlir::ArrayRef;
42 using ::mlir::ModuleOp;
43 using ::mlir::Operation;
44 using ::mlir::SymbolTable;
45 using ::mlir::SymbolTableCollection;
46 using ::mlir::SymbolUserMap;
47
48 // This only includes some preliminary checks as this is a short term solution.
AreEquivalent(mlir::func::FuncOp & lhs,mlir::func::FuncOp & rhs)49 bool AreEquivalent(mlir::func::FuncOp& lhs, mlir::func::FuncOp& rhs) {
50 if (lhs.getFunctionType() != rhs.getFunctionType()) return false;
51
52 for (auto arg_pair : llvm::zip(lhs.getArguments(), rhs.getArguments())) {
53 auto& lhs_arg = std::get<0>(arg_pair);
54 auto& rhs_arg = std::get<1>(arg_pair);
55 if (lhs_arg.getType() != rhs_arg.getType()) return false;
56 }
57
58 auto lhs_ops = lhs.getBody().getOps();
59 auto rhs_ops = rhs.getBody().getOps();
60 if (std::distance(lhs_ops.begin(), lhs_ops.end()) !=
61 std::distance(rhs_ops.begin(), rhs_ops.end()))
62 return false;
63
64 for (auto op_pair : llvm::zip(lhs_ops, rhs_ops)) {
65 auto& lhs_op = std::get<0>(op_pair);
66 auto& rhs_op = std::get<1>(op_pair);
67 if (lhs_op.getName() != rhs_op.getName()) return false;
68 if (lhs_op.getNumRegions() != rhs_op.getNumRegions()) return false;
69 if (lhs_op.getNumSuccessors() != rhs_op.getNumSuccessors()) return false;
70 if (!std::equal(lhs_op.getOperandTypes().begin(),
71 lhs_op.getOperandTypes().end(),
72 rhs_op.getOperandTypes().begin()))
73 return false;
74 if (!std::equal(lhs_op.getResultTypes().begin(),
75 lhs_op.getResultTypes().end(),
76 rhs_op.getResultTypes().begin()))
77 return false;
78 }
79
80 return true;
81 }
82
83 // Deduplicate the functions if all users are BatchFunctionOp and have the same
84 // shared_name.
85 //
86 // TODO(b/192463730): this is the short term solution and not needed anymore
87 // after the shape inference pass is revamped with ideal solution
88 // (b/192463730#comment11).
89 class DeduplicateFunctionsInovkedByBatchFunction
90 : public mlir::PassWrapper<DeduplicateFunctionsInovkedByBatchFunction,
91 mlir::OperationPass<mlir::ModuleOp>> {
92 public:
93 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
94 DeduplicateFunctionsInovkedByBatchFunction)
95
96 private:
getArgument() const97 llvm::StringRef getArgument() const final {
98 return "tfrt-deduplicate-functions-invoked-by-batch-function";
99 }
getDescription() const100 llvm::StringRef getDescription() const final {
101 return "Deduplicate the functions invoked by tf.BatchFunction with the "
102 "same shared_name";
103 }
runOnOperation()104 void runOnOperation() override {
105 if (failed(Run())) {
106 signalPassFailure();
107 }
108 }
109
110 mlir::LogicalResult Run();
111 };
112
Run()113 mlir::LogicalResult DeduplicateFunctionsInovkedByBatchFunction::Run() {
114 ModuleOp module = getOperation();
115 SymbolTableCollection symbol_table_collection;
116 SymbolTable& symbol_table = symbol_table_collection.getSymbolTable(module);
117 SymbolUserMap symbol_users(symbol_table_collection, module);
118
119 // Categorize the functions invoked by BatchFunctionOp by its shared_name.
120 llvm::StringMap<llvm::SmallVector<mlir::func::FuncOp, 2>>
121 shared_name_to_func_ops;
122
123 for (auto func :
124 llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
125 ArrayRef<Operation*> users = symbol_users.getUsers(func);
126 llvm::StringRef shared_name;
127 // Deduplicate the function only if all users are BatchFunctionOp and have
128 // the same shared_name
129 if (!users.empty() && llvm::all_of(users, [&shared_name](Operation* user) {
130 auto op = llvm::dyn_cast_or_null<mlir::TF::BatchFunctionOp>(user);
131 // User is not a BatchFunctionOp
132 if (!op) return false;
133 if (shared_name.empty()) {
134 shared_name = op.shared_name();
135 return true;
136 }
137 return shared_name == op.shared_name();
138 })) {
139 shared_name_to_func_ops[shared_name].push_back(func);
140 }
141 }
142
143 for (auto& it : shared_name_to_func_ops) {
144 auto& func_ops = it.second;
145 mlir::func::FuncOp& func_op_to_keep = func_ops.front();
146 for (mlir::func::FuncOp& func_op_to_remove : llvm::drop_begin(func_ops)) {
147 if (!AreEquivalent(func_op_to_keep, func_op_to_remove)) {
148 return func_op_to_remove.emitError(
149 "func_ops for BatchFunctionOp with the same shared name are "
150 "different");
151 }
152 if (failed(SymbolTable::replaceAllSymbolUses(
153 func_op_to_remove, func_op_to_keep.getSymNameAttr(), module))) {
154 return func_op_to_remove.emitError("unable to replace the symbol use");
155 }
156 symbol_table.erase(func_op_to_remove);
157 }
158 }
159
160 return mlir::success();
161 }
162 } // namespace
163
164 std::unique_ptr<mlir::OperationPass<ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass()165 CreateDeduplicateFunctionsInovkedByBatchFunctionPass() {
166 return std::make_unique<DeduplicateFunctionsInovkedByBatchFunction>();
167 }
168
169 static mlir::PassRegistration<DeduplicateFunctionsInovkedByBatchFunction>
170 register_pass;
171
172 } // namespace tfrt_compiler
173 } // namespace tensorflow
174