/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <queue>

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"

namespace mlir {
namespace TFDevice {
namespace {

constexpr char kBadDecompositionMessage[] =
    "Resource ops decomposition did not converge";

// Decomposing resource ops should not take more than a few iterations (2-3) to
// converge as only a few patterns create new resource ops that can be further
// decomposed. The rest of the iterations are enough to clean up any dead ops
// created by decomposition.
constexpr int kMaxIterations = 10;

// Populates `reachable_functions` with all functions that can be reached from
// device cluster ops.
void PopulateClusterReachableFunctions(
    ModuleOp module, SmallPtrSetImpl<Operation*>& reachable_functions) {
  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);

  // Create map from caller to set of all callee(s).
  llvm::DenseMap<func::FuncOp, llvm::DenseSet<func::FuncOp>> caller_callee_map;

  // Use worklist to populate the set of reachable functions.
  std::queue<func::FuncOp> function_worklist;

  // Iterates over all functions within the module to (1) create caller-callee
  // map, and (2) initialize function worklist with functions referenced from
  // device cluster ops.
  for (auto func : module.getOps<func::FuncOp>()) {
    for (auto user : symbol_map.getUsers(func)) {
      // Populate caller-callee map.
      if (func::FuncOp caller = user->getParentOfType<func::FuncOp>())
        caller_callee_map[caller].insert(func);
      // Initialize function worklist with functions referenced in device
      // cluster.
      if (auto cluster = user->getParentOfType<tf_device::ClusterOp>()) {
        if (reachable_functions.insert(func).second)
          function_worklist.push(func);
      }
    }
  }

  // Uses worklist algorithm to insert all functions reachable from device
  // cluster ops.
  while (!function_worklist.empty()) {
    func::FuncOp caller = function_worklist.front();
    function_worklist.pop();
    for (auto callee : caller_callee_map[caller]) {
      if (reachable_functions.insert(callee).second)
        function_worklist.push(callee);
    }
  }
}

// Applies patterns locally on ops within `op_with_regions` until convergence
// or `max_iterations` is reached. Returns failure if resource ops
// decomposition does not converge after `max_iterations`.
// TODO(prakalps): This can be useful to a lot of other passes in bridge.
// Extract out as a separate utility.
LogicalResult ApplyPatternsLocallyUntilConverged(
    Operation* op_with_regions, FrozenRewritePatternSet& patterns,
    int max_iterations) {
  bool changed = true;
  int iteration = 0;
  while (changed && (iteration++ < max_iterations)) {
    changed = false;
    auto walk_result =
        op_with_regions->walk([&patterns, &changed](Operation* operation) {
          bool op_changed;
          if (failed(applyOpPatternsAndFold(operation, patterns, &op_changed)))
            return WalkResult::interrupt();
          changed |= op_changed;
          return WalkResult::advance();
        });
    if (walk_result.wasInterrupted()) return failure();
  }
  // Return failure if `op_with_regions` was still modified in the last
  // iteration, i.e. the rewrite did not converge.
  return success(!changed);
}

// Applies patterns only in device clusters and in functions reachable from
// such clusters. Returns failure if it fails to converge in `max_iterations`.
// TODO(prakalps): This can be useful to a lot of other passes in bridge.
// Extract out as a separate utility.
LogicalResult ApplyPatternsInClusterAndReachableFunctions(
    ModuleOp module, FrozenRewritePatternSet& patterns, int max_iterations) {
  SmallPtrSet<Operation*, 16> reachable_functions;
  PopulateClusterReachableFunctions(module, reachable_functions);

  // Apply patterns to reachable functions.
  for (Operation* op : reachable_functions) {
    assert(isa<func::FuncOp>(op));
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      return op->emitError() << kBadDecompositionMessage;
    }
  }

  // Apply patterns to device cluster ops.
  // Note: This module-wide search for cluster ops is a bit wasteful as we
  // could have collected many cluster ops while populating reachable
  // functions. But we would still need a walk to find all clusters that do
  // not reference any function.
  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    // If we have already applied patterns to a function then we can skip
    // applying patterns to any device clusters it contains.
    if (reachable_functions.contains(func)) continue;

    auto walk_result = func.walk([&](tf_device::ClusterOp cluster) {
      // Cluster ops are not isolated from above so we cannot use the
      // `applyPatternsAndFoldGreedily` utility. Instead we apply patterns
      // locally on each op within the cluster until convergence.
      if (failed(ApplyPatternsLocallyUntilConverged(cluster, patterns,
                                                    max_iterations))) {
        cluster.emitError() << kBadDecompositionMessage;
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    if (walk_result.wasInterrupted()) return failure();
  }

  return success();
}

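// Pass that greedily applies the resource op decomposition patterns to the
// function it is run on.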
struct DecomposeResourceOpsPass
    : public DecomposeResourceOpsPassBase<DecomposeResourceOpsPass> {
  void runOnOperation() override {
    // Add lowering patterns to the list.
    RewritePatternSet patterns(&getContext());
    TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);

    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      getOperation().emitError() << kBadDecompositionMessage;
      signalPassFailure();
    }
  }
};

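// Pass that applies the resource op decomposition patterns only within
// tf_device.cluster ops and within functions reachable from such clusters.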
struct DecomposeResourceOpsInClusterPass
    : public DecomposeResourceOpsInClusterPassBase<
          DecomposeResourceOpsInClusterPass> {
  void runOnOperation() override {
    // Add lowering patterns to the list.
    RewritePatternSet patterns(&getContext());
    TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
    FrozenRewritePatternSet frozen_patterns(std::move(patterns));

    if (failed(ApplyPatternsInClusterAndReachableFunctions(
            getOperation(), frozen_patterns, kMaxIterations)))
      signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeResourceOpsPass() {
  return std::make_unique<DecomposeResourceOpsPass>();
}

std::unique_ptr<OperationPass<ModuleOp>>
CreateDecomposeResourceOpsInClusterPass() {
  return std::make_unique<DecomposeResourceOpsInClusterPass>();
}

}  // namespace TFDevice
}  // namespace mlir