1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <queue>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
20 #include "mlir/Pass/Pass.h"  // from @llvm-project
21 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
24 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes_detail.h"
27 
28 namespace mlir {
29 namespace TFDevice {
30 namespace {
31 
// Diagnostic text attached to the failing op whenever resource op
// decomposition does not reach a fixed point.
constexpr char kBadDecompositionMessage[] =
    "Resource ops decomposition did not converge";

// Upper bound on rewrite iterations. Only a handful of patterns produce new
// resource ops that themselves need decomposing, so convergence normally
// takes 2-3 iterations; the remaining budget is there to clean up dead ops
// left behind by decomposition.
constexpr int kMaxIterations{10};
40 
41 // Populates `reachable_functions` with all functions that can be reached from
42 // device cluster ops.
PopulateClusterReachableFunctions(ModuleOp module,SmallPtrSetImpl<Operation * > & reachable_functions)43 void PopulateClusterReachableFunctions(
44     ModuleOp module, SmallPtrSetImpl<Operation*>& reachable_functions) {
45   SymbolTableCollection table;
46   SymbolUserMap symbol_map(table, module);
47 
48   // Create map from caller to set of all callee(s).
49   llvm::DenseMap<func::FuncOp, llvm::DenseSet<func::FuncOp>> caller_callee_map;
50 
51   // Use worklist to populate the set of reachable functions.
52   std::queue<func::FuncOp> function_worklist;
53 
54   // Iterates over all functions within the module to (1) create caller-callee
55   // map, and (2) initialize function worklist with functions referenced from
56   // device cluster ops.
57   for (auto func : module.getOps<func::FuncOp>()) {
58     for (auto user : symbol_map.getUsers(func)) {
59       // Populate caller-callee map.
60       if (func::FuncOp caller = user->getParentOfType<func::FuncOp>())
61         caller_callee_map[caller].insert(func);
62       // Initialize function worklist with functions refrerenced in device
63       // cluster.
64       if (auto cluster = user->getParentOfType<tf_device::ClusterOp>()) {
65         if (reachable_functions.insert(func).second)
66           function_worklist.push(func);
67       }
68     }
69   }
70 
71   // Uses worklist algorithm to insert all functions reachable from device
72   // cluster ops.
73   while (!function_worklist.empty()) {
74     func::FuncOp caller = function_worklist.front();
75     function_worklist.pop();
76     for (auto callee : caller_callee_map[caller]) {
77       if (reachable_functions.insert(callee).second)
78         function_worklist.push(callee);
79     }
80   }
81 }
82 
83 // Applies patterns locally on ops within `cluster` until convergence or
84 // `max_iterations` are reached. Returns failure if resource ops decomposition
85 // does not converge after `max_iterations`.
86 // TODO(prakalps): This can be useful to a lot of other passes in bridge.
87 // Extract out as a separate utility.
ApplyPatternsLocallyUntilConverged(Operation * op_with_regions,FrozenRewritePatternSet & patterns,int max_iterations)88 LogicalResult ApplyPatternsLocallyUntilConverged(
89     Operation* op_with_regions, FrozenRewritePatternSet& patterns,
90     int max_iterations) {
91   bool changed = true;
92   int iteration = 0;
93   while (changed && (iteration++ < max_iterations)) {
94     changed = false;
95     auto walk_result =
96         op_with_regions->walk([&patterns, &changed](Operation* operation) {
97           bool op_changed;
98           if (failed(applyOpPatternsAndFold(operation, patterns, &op_changed)))
99             return WalkResult::interrupt();
100           changed |= op_changed;
101           return WalkResult::advance();
102         });
103     if (walk_result.wasInterrupted()) return failure();
104   }
105   // Return failure is `op_with_region` was modified changed in last iteration.
106   return success(!changed);
107 }
108 
109 // Applies patterns in only device clusters and functions reachable from such
110 // clusters. Returns failure if it fails to converge in `max_iterations`.
111 // TODO(prakalps): This can be useful to a lot of other passes in bridge.
112 // Extract out as a separate utility.
ApplyPatternsInClusterAndReachableFunctions(ModuleOp module,FrozenRewritePatternSet & patterns,int max_iterations)113 LogicalResult ApplyPatternsInClusterAndReachableFunctions(
114     ModuleOp module, FrozenRewritePatternSet& patterns, int max_iterations) {
115   SmallPtrSet<Operation*, 16> reachable_functions;
116   PopulateClusterReachableFunctions(module, reachable_functions);
117 
118   // Apply patterns to reachable functions.
119   for (Operation* op : reachable_functions) {
120     assert(isa<func::FuncOp>(op));
121     if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
122       return op->emitError() << kBadDecompositionMessage;
123     }
124   }
125 
126   // Apply patterns to device cluster ops.
127   // Note: This module search for cluster ops is a bit wasteful as we could have
128   // collected many cluster ops when we were populating reachable functions. But
129   // we would still need to do a walk to find all clusters that do not
130   // reference any function.
131   for (func::FuncOp func : module.getOps<func::FuncOp>()) {
132     // If we have already applied patterns to a function then we can skip
133     // applying patterns to any device clusters it contains.
134     if (reachable_functions.contains(func)) continue;
135 
136     auto walk_result = func.walk([&](tf_device::ClusterOp cluster) {
137       // Cluster ops are not isolated from above so we cannot use
138       // `applyPatternsAndFoldGreedily` utility. Instead we apply patterns
139       // locally on each op within the cluster until convergence.
140       if (failed(ApplyPatternsLocallyUntilConverged(cluster, patterns,
141                                                     max_iterations))) {
142         cluster.emitError() << kBadDecompositionMessage;
143         return WalkResult::interrupt();
144       }
145       return WalkResult::advance();
146     });
147     if (walk_result.wasInterrupted()) return failure();
148   }
149 
150   return success();
151 }
152 
153 struct DecomposeResourceOpsPass
154     : public DecomposeResourceOpsPassBase<DecomposeResourceOpsPass> {
runOnOperationmlir::TFDevice::__anonbde900590111::DecomposeResourceOpsPass155   void runOnOperation() override {
156     // Add lowering patterns to the list.
157     RewritePatternSet patterns(&getContext());
158     TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
159 
160     if (failed(applyPatternsAndFoldGreedily(getOperation(),
161                                             std::move(patterns)))) {
162       getOperation().emitError() << kBadDecompositionMessage;
163       signalPassFailure();
164     }
165   }
166 };
167 
168 struct DecomposeResourceOpsInClusterPass
169     : public DecomposeResourceOpsInClusterPassBase<
170           DecomposeResourceOpsInClusterPass> {
runOnOperationmlir::TFDevice::__anonbde900590111::DecomposeResourceOpsInClusterPass171   void runOnOperation() override {
172     // Add lowering patterns to the list.
173     RewritePatternSet patterns(&getContext());
174     TF::PopulateDecomposeResourceOpsPatterns(&getContext(), &patterns);
175     FrozenRewritePatternSet frozen_patterns(std::move(patterns));
176 
177     if (failed(ApplyPatternsInClusterAndReachableFunctions(
178             getOperation(), frozen_patterns, kMaxIterations)))
179       signalPassFailure();
180   }
181 };
182 
183 }  // namespace
184 
CreateDecomposeResourceOpsPass()185 std::unique_ptr<OperationPass<func::FuncOp>> CreateDecomposeResourceOpsPass() {
186   return std::make_unique<DecomposeResourceOpsPass>();
187 }
188 
189 std::unique_ptr<OperationPass<ModuleOp>>
CreateDecomposeResourceOpsInClusterPass()190 CreateDecomposeResourceOpsInClusterPass() {
191   return std::make_unique<DecomposeResourceOpsInClusterPass>();
192 }
193 
194 }  // namespace TFDevice
195 }  // namespace mlir
196