1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/iterator_range.h"
23 #include "llvm/Support/Casting.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
28 #include "mlir/IR/Operation.h"  // from @llvm-project
29 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
30 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
31 #include "mlir/Support/LLVM.h"  // from @llvm-project
32 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
33 #include "mlir/Transforms/Passes.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
36 
37 namespace tensorflow {
38 namespace tfrt_compiler {
39 namespace {
40 
41 using ::mlir::ArrayRef;
42 using ::mlir::ModuleOp;
43 using ::mlir::Operation;
44 using ::mlir::SymbolTable;
45 using ::mlir::SymbolTableCollection;
46 using ::mlir::SymbolUserMap;
47 
48 // This only includes some preliminary checks as this is a short term solution.
AreEquivalent(mlir::func::FuncOp & lhs,mlir::func::FuncOp & rhs)49 bool AreEquivalent(mlir::func::FuncOp& lhs, mlir::func::FuncOp& rhs) {
50   if (lhs.getFunctionType() != rhs.getFunctionType()) return false;
51 
52   for (auto arg_pair : llvm::zip(lhs.getArguments(), rhs.getArguments())) {
53     auto& lhs_arg = std::get<0>(arg_pair);
54     auto& rhs_arg = std::get<1>(arg_pair);
55     if (lhs_arg.getType() != rhs_arg.getType()) return false;
56   }
57 
58   auto lhs_ops = lhs.getBody().getOps();
59   auto rhs_ops = rhs.getBody().getOps();
60   if (std::distance(lhs_ops.begin(), lhs_ops.end()) !=
61       std::distance(rhs_ops.begin(), rhs_ops.end()))
62     return false;
63 
64   for (auto op_pair : llvm::zip(lhs_ops, rhs_ops)) {
65     auto& lhs_op = std::get<0>(op_pair);
66     auto& rhs_op = std::get<1>(op_pair);
67     if (lhs_op.getName() != rhs_op.getName()) return false;
68     if (lhs_op.getNumRegions() != rhs_op.getNumRegions()) return false;
69     if (lhs_op.getNumSuccessors() != rhs_op.getNumSuccessors()) return false;
70     if (!std::equal(lhs_op.getOperandTypes().begin(),
71                     lhs_op.getOperandTypes().end(),
72                     rhs_op.getOperandTypes().begin()))
73       return false;
74     if (!std::equal(lhs_op.getResultTypes().begin(),
75                     lhs_op.getResultTypes().end(),
76                     rhs_op.getResultTypes().begin()))
77       return false;
78   }
79 
80   return true;
81 }
82 
83 // Deduplicate the functions if all users are BatchFunctionOp and have the same
84 // shared_name.
85 //
86 // TODO(b/192463730): this is the short term solution and not needed anymore
87 // after the shape inference pass is revamped with ideal solution
88 // (b/192463730#comment11).
89 class DeduplicateFunctionsInovkedByBatchFunction
90     : public mlir::PassWrapper<DeduplicateFunctionsInovkedByBatchFunction,
91                                mlir::OperationPass<mlir::ModuleOp>> {
92  public:
93   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
94       DeduplicateFunctionsInovkedByBatchFunction)
95 
96  private:
getArgument() const97   llvm::StringRef getArgument() const final {
98     return "tfrt-deduplicate-functions-invoked-by-batch-function";
99   }
getDescription() const100   llvm::StringRef getDescription() const final {
101     return "Deduplicate the functions invoked by tf.BatchFunction with the "
102            "same shared_name";
103   }
runOnOperation()104   void runOnOperation() override {
105     if (failed(Run())) {
106       signalPassFailure();
107     }
108   }
109 
110   mlir::LogicalResult Run();
111 };
112 
Run()113 mlir::LogicalResult DeduplicateFunctionsInovkedByBatchFunction::Run() {
114   ModuleOp module = getOperation();
115   SymbolTableCollection symbol_table_collection;
116   SymbolTable& symbol_table = symbol_table_collection.getSymbolTable(module);
117   SymbolUserMap symbol_users(symbol_table_collection, module);
118 
119   // Categorize the functions invoked by BatchFunctionOp by its shared_name.
120   llvm::StringMap<llvm::SmallVector<mlir::func::FuncOp, 2>>
121       shared_name_to_func_ops;
122 
123   for (auto func :
124        llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
125     ArrayRef<Operation*> users = symbol_users.getUsers(func);
126     llvm::StringRef shared_name;
127     // Deduplicate the function only if all users are BatchFunctionOp and have
128     // the same shared_name
129     if (!users.empty() && llvm::all_of(users, [&shared_name](Operation* user) {
130           auto op = llvm::dyn_cast_or_null<mlir::TF::BatchFunctionOp>(user);
131           // User is not a BatchFunctionOp
132           if (!op) return false;
133           if (shared_name.empty()) {
134             shared_name = op.shared_name();
135             return true;
136           }
137           return shared_name == op.shared_name();
138         })) {
139       shared_name_to_func_ops[shared_name].push_back(func);
140     }
141   }
142 
143   for (auto& it : shared_name_to_func_ops) {
144     auto& func_ops = it.second;
145     mlir::func::FuncOp& func_op_to_keep = func_ops.front();
146     for (mlir::func::FuncOp& func_op_to_remove : llvm::drop_begin(func_ops)) {
147       if (!AreEquivalent(func_op_to_keep, func_op_to_remove)) {
148         return func_op_to_remove.emitError(
149             "func_ops for BatchFunctionOp with the same shared name are "
150             "different");
151       }
152       if (failed(SymbolTable::replaceAllSymbolUses(
153               func_op_to_remove, func_op_to_keep.getSymNameAttr(), module))) {
154         return func_op_to_remove.emitError("unable to replace the symbol use");
155       }
156       symbol_table.erase(func_op_to_remove);
157     }
158   }
159 
160   return mlir::success();
161 }
162 }  // namespace
163 
164 std::unique_ptr<mlir::OperationPass<ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass()165 CreateDeduplicateFunctionsInovkedByBatchFunctionPass() {
166   return std::make_unique<DeduplicateFunctionsInovkedByBatchFunction>();
167 }
168 
// File-local static object of type mlir::PassRegistration for the pass above.
// NOTE(review): presumably its construction is what registers the pass under
// the string returned by getArgument() — confirm against the MLIR pass
// registration documentation.
static mlir::PassRegistration<DeduplicateFunctionsInovkedByBatchFunction>
    register_pass;
171 
172 }  // namespace tfrt_compiler
173 }  // namespace tensorflow
174