1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "mlir/Analysis/CallGraph.h"  // from @llvm-project
19 #include "mlir/Dialect/Affine/Utils.h"  // from @llvm-project
20 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
21 #include "mlir/Pass/Pass.h"  // from @llvm-project
22 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
23 #include "mlir/Support/LLVM.h"  // from @llvm-project
24 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
27 
28 namespace mlir {
29 namespace TF {
30 
31 namespace {
32 
33 // Check that there is no recursion in the module's call graph.
CheckNoRecursion(ModuleOp module,CallGraph & call_graph)34 LogicalResult CheckNoRecursion(ModuleOp module, CallGraph &call_graph) {
35   for (llvm::scc_iterator<const CallGraph *> scci =
36            llvm::scc_begin<const CallGraph *>(&call_graph);
37        !scci.isAtEnd(); ++scci) {
38     if (scci.hasCycle()) {
39       auto err = module.emitError()
40                  << "A recursive call graph cannot be transformed to "
41                     "one use for all functions. Functions in the "
42                     "recursive cycle are: ";
43       llvm::interleaveComma(*scci, err, [&](CallGraphNode *node) {
44         err << node->getCallableRegion()->getLoc();
45       });
46       return err;
47     }
48   }
49   return success();
50 }
51 
52 // Clones FuncOp's until they have a single use only (or no users).
53 //
54 // The tf-shape-inference pass doesn't support functions that have more than
55 // a single use. But some real code from frontends does end up creating code
56 // like that. For example, the same LSTM cell function or loop body function
57 // will be reused.
58 //
59 // This pass clones functions as needed to establish the invariant that all
60 // functions have a single use. This can in principle cause exponential code
61 // size bloat, and should in general be guided by a proper cost model.
62 //
63 // There are two factors which should be considered by a principled replacement
64 // to this pass:
65 //
66 // 1. TF currently relies on "sufficiently good shape inference" for
67 // correctness so for now the cost of doing this seems acceptable since
68 // pathological cases haven't hit us yet.
69 //
70 // 2. Cloning functions can help by allowing code to be specialized (much as
71 // inlining does). In fact, tf-shape-inference attempts to do specialization
72 // of callees which is difficult if callees have multiple uses.
73 class GuaranteeAllFuncsOneUse
74     : public GuaranteeAllFuncsOneUsePassBase<GuaranteeAllFuncsOneUse> {
75  public:
runOnOperation()76   void runOnOperation() override {
77     if (failed(Run())) {
78       signalPassFailure();
79     }
80   }
81 
Run()82   LogicalResult Run() {
83     auto module = getOperation();
84 
85     // Overall strategy:
86     // Fixed point iteration, iteratively applying a rule that clones
87     // any FuncOp with more than one use to eliminate its uses.
88     SymbolTableCollection symbol_table_collection;
89     SymbolTable &symbol_table = symbol_table_collection.getSymbolTable(module);
90     bool made_changes = false;
91 
92     if (failed(CheckNoRecursion(module, getAnalysis<CallGraph>())))
93       return failure();
94 
95     do {
96       SymbolUserMap symbol_users(symbol_table_collection, module);
97 
98       made_changes = false;
99       for (auto func :
100            llvm::make_early_inc_range(module.getOps<func::FuncOp>())) {
101         ArrayRef<Operation *> users = symbol_users.getUsers(func);
102         if (users.size() <= 1) {
103           continue;
104         }
105 
106         // At this point, we know we are going to change the module.
107         made_changes = true;
108         for (Operation *user : users.drop_front()) {
109           func::FuncOp new_func = func.clone();
110           symbol_table.insert(new_func);
111           new_func.setPrivate();
112           if (failed(SymbolTable::replaceAllSymbolUses(
113                   func, new_func.getSymNameAttr(), user))) {
114             return func.emitError() << "could not replace symbol use";
115           }
116         }
117       }
118     } while (made_changes);
119 
120     return success();
121   }
122 };
123 
124 }  // namespace
125 
CreateGuaranteeAllFuncsOneUsePass()126 std::unique_ptr<OperationPass<ModuleOp>> CreateGuaranteeAllFuncsOneUsePass() {
127   return std::make_unique<GuaranteeAllFuncsOneUse>();
128 }
129 
130 }  // namespace TF
131 
132 }  // namespace mlir
133