1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "mlir/Analysis/CallGraph.h" // from @llvm-project
19 #include "mlir/Dialect/Affine/Utils.h" // from @llvm-project
20 #include "mlir/IR/SymbolTable.h" // from @llvm-project
21 #include "mlir/Pass/Pass.h" // from @llvm-project
22 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
23 #include "mlir/Support/LLVM.h" // from @llvm-project
24 #include "mlir/Support/LogicalResult.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
27
28 namespace mlir {
29 namespace TF {
30
31 namespace {
32
33 // Check that there is no recursion in the module's call graph.
CheckNoRecursion(ModuleOp module,CallGraph & call_graph)34 LogicalResult CheckNoRecursion(ModuleOp module, CallGraph &call_graph) {
35 for (llvm::scc_iterator<const CallGraph *> scci =
36 llvm::scc_begin<const CallGraph *>(&call_graph);
37 !scci.isAtEnd(); ++scci) {
38 if (scci.hasCycle()) {
39 auto err = module.emitError()
40 << "A recursive call graph cannot be transformed to "
41 "one use for all functions. Functions in the "
42 "recursive cycle are: ";
43 llvm::interleaveComma(*scci, err, [&](CallGraphNode *node) {
44 err << node->getCallableRegion()->getLoc();
45 });
46 return err;
47 }
48 }
49 return success();
50 }
51
52 // Clones FuncOp's until they have a single use only (or no users).
53 //
54 // The tf-shape-inference pass doesn't support functions that have more than
55 // a single use. But some real code from frontends does end up creating code
56 // like that. For example, the same LSTM cell function or loop body function
57 // will be reused.
58 //
59 // This pass clones functions as needed to establish the invariant that all
60 // functions have a single use. This can in principle cause exponential code
61 // size bloat, and should in general be guided by a proper cost model.
62 //
63 // There are two factors which should be considered by a principled replacement
64 // to this pass:
65 //
66 // 1. TF currently relies on "sufficiently good shape inference" for
67 // correctness so for now the cost of doing this seems acceptable since
68 // pathological cases haven't hit us yet.
69 //
70 // 2. Cloning functions can help by allowing code to be specialized (much as
71 // inlining does). In fact, tf-shape-inference attempts to do specialization
72 // of callees which is difficult if callees have multiple uses.
73 class GuaranteeAllFuncsOneUse
74 : public GuaranteeAllFuncsOneUsePassBase<GuaranteeAllFuncsOneUse> {
75 public:
runOnOperation()76 void runOnOperation() override {
77 if (failed(Run())) {
78 signalPassFailure();
79 }
80 }
81
Run()82 LogicalResult Run() {
83 auto module = getOperation();
84
85 // Overall strategy:
86 // Fixed point iteration, iteratively applying a rule that clones
87 // any FuncOp with more than one use to eliminate its uses.
88 SymbolTableCollection symbol_table_collection;
89 SymbolTable &symbol_table = symbol_table_collection.getSymbolTable(module);
90 bool made_changes = false;
91
92 if (failed(CheckNoRecursion(module, getAnalysis<CallGraph>())))
93 return failure();
94
95 do {
96 SymbolUserMap symbol_users(symbol_table_collection, module);
97
98 made_changes = false;
99 for (auto func :
100 llvm::make_early_inc_range(module.getOps<func::FuncOp>())) {
101 ArrayRef<Operation *> users = symbol_users.getUsers(func);
102 if (users.size() <= 1) {
103 continue;
104 }
105
106 // At this point, we know we are going to change the module.
107 made_changes = true;
108 for (Operation *user : users.drop_front()) {
109 func::FuncOp new_func = func.clone();
110 symbol_table.insert(new_func);
111 new_func.setPrivate();
112 if (failed(SymbolTable::replaceAllSymbolUses(
113 func, new_func.getSymNameAttr(), user))) {
114 return func.emitError() << "could not replace symbol use";
115 }
116 }
117 }
118 } while (made_changes);
119
120 return success();
121 }
122 };
123
124 } // namespace
125
CreateGuaranteeAllFuncsOneUsePass()126 std::unique_ptr<OperationPass<ModuleOp>> CreateGuaranteeAllFuncsOneUsePass() {
127 return std::make_unique<GuaranteeAllFuncsOneUse>();
128 }
129
130 } // namespace TF
131
132 } // namespace mlir
133