1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation forms clusters from instructions in the same island that
// are assigned to the same device. Clusters are represented as regions.
// Note that side-effecting ops are not correctly handled yet.
19
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Block.h" // from @llvm-project
25 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
34 #include "tensorflow/core/platform/logging.h"
35
36 namespace mlir {
37 namespace TFDevice {
38
39 namespace {
40
// Pass that wraps same-device runs of ops into `tf_device.launch` regions.
// All logic lives in runOnOperation(); the base class supplies the pass
// name, argument and description.
struct ClusterFormationPass
    : public TF::ClusterFormationPassBase<ClusterFormationPass> {
  void runOnOperation() override;
};
45
// Cluster structure captures all the operations that are assigned to same
// device and can form a legal strict cluster.
// Ops must follow same ordering in their parent block. We rely on this
// assumption to perform analysis.
struct Cluster {
  // Operations in the cluster, in their original block order.
  llvm::SmallVector<Operation*, 4> ops;
  // Device string shared by every op in `ops` (value of their "device" attr).
  StringRef device;
};
54
GetDevice(Operation * op)55 StringRef GetDevice(Operation* op) {
56 auto device_attr = op->getAttrOfType<StringAttr>("device");
57 return device_attr ? device_attr.getValue() : "";
58 }
59
60 // An op can be merged into cluster if all of its operands are one of the
61 // following:
62 // 1) A block argument
63 // 2) A value produced by other islands
64 // 1) Defined before the cluster
65 // 2) Defined by an operation in the cluster
66 // TODO(ycao): This is not optimal as it doesn't consider the situation of
67 // defining_op's operands all meet the requirements above. In that case, the
68 // defining_op can be moved and to_merge op would be legal to absorb.
69 // TODO(ycao): Take op side-effects into consideration since they can not be
70 // re-ordered but forming clusters of non-continuous ops is effectively
71 // re-ordering them..
CanMergeIntoCluster(const Cluster & c,Operation * to_merge)72 bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
73 return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
74 // Block arguments.
75 if (operand.isa<BlockArgument>()) return true;
76
77 Operation* defining_op = operand.getDefiningOp();
78
79 // Operand produced by other islands.
80 if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
81
82 // Defining op is before the cluster.
83 if (defining_op->isBeforeInBlock(c.ops.front())) return true;
84
85 // Defining op is between first and last operation in cluster. Note that
86 // cluster may contain operations that are non-continuous in their original
87 // block, thus we also need to check defining_op is also assigned to
88 // cluster's device to be sure. This is a faster check than linearly
89 // searching through all ops in cluster.
90 if (defining_op->isBeforeInBlock(c.ops.back()->getNextNode()) &&
91 GetDevice(defining_op) == c.device)
92 return true;
93
94 // Other cases, operand is generated after or outside the cluster, this
95 // means it is illegal to merge operation.
96 return false;
97 });
98 }
99
ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,tf_device::LaunchOp launch_op)100 void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
101 tf_device::LaunchOp launch_op) {
102 Region* launch_op_region = &launch_op.body();
103 for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
104 Value from = std::get<0>(p);
105 // TODO(jingpu): move this to RegionUtils.h in MLIR core.
106 for (auto& use : llvm::make_early_inc_range(from.getUses())) {
107 if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
108 continue;
109 use.set(std::get<1>(p));
110 }
111 }
112 }
113
114 // Get all escaped live-out values of a region.
GetLiveOuts(Region * region,llvm::SmallVectorImpl<Value> * live_outs)115 void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
116 live_outs->clear();
117
118 for (Operation& op : region->front()) {
119 for (Value v : op.getResults()) {
120 // A value is live-out if any of its users are not inside value producer's
121 // region.
122 bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
123 return !region->isAncestor(user->getParentRegion());
124 });
125
126 if (is_live_out) live_outs->emplace_back(v);
127 }
128 }
129 }
130
131 // Build a `tf_device.launch` op with a region that contains all the operations
132 // in given cluster. Then all ops in cluster are replaced by `tf_device.launch`.
BuildLaunchForCluster(const Cluster & c,OpBuilder * builder)133 void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
134 // Set insertion point to right after all operations in cluster.
135 builder->setInsertionPoint(c.ops.back()->getNextNode());
136
137 // Create a stand-alone region to hold all instructions in the cluster.
138 Region region;
139 region.push_back(new Block);
140
141 // Move all operations in cluster to newly created region, stripping their
142 // "device" attribute since launch op already carries device information.
143 Block* block = ®ion.front();
144 for (Operation* op : c.ops) {
145 op->moveBefore(block, block->end());
146 op->removeAttr(builder->getStringAttr("device"));
147 }
148
149 // Get all escaped live-out values of region, they are used later to determine
150 // return values and types of launch op.
151 llvm::SmallVector<Value, 4> live_outs;
152 GetLiveOuts(®ion, &live_outs);
153
154 // Build a `tf_device.return` op at end of region, with all live-out values
155 // as operand.
156 OpBuilder return_builder(builder->getContext());
157 return_builder.setInsertionPointToEnd(block);
158 return_builder.create<tf_device::ReturnOp>(return_builder.getUnknownLoc(),
159 live_outs);
160
161 llvm::SmallVector<Type, 4> live_out_types;
162 live_out_types.reserve(live_outs.size());
163 for (Value v : live_outs) {
164 live_out_types.emplace_back(v.getType());
165 }
166
167 tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
168 builder->getUnknownLoc(), builder->getStringAttr(c.device),
169 live_out_types);
170
171 // Attach the region to launch_op.
172 launch_op.body().takeBody(region);
173
174 // Replace any external uses of live-out values with return values of launch
175 // op. So live-out values no longer escape the region.
176 ReplaceLiveOutExternalUses(live_outs, launch_op);
177 }
178
BuildClusters(Block * block,OpBuilder builder)179 void BuildClusters(Block* block, OpBuilder builder) {
180 // Iteratively find clusters of different devices within an island.
181 // Whenever we see an operation that is assigned to an accelerator device
182 // (ie. device != ""), we try to merge it into the last cluster of same
183 // device. If that is infeasible (say because of violating def-before-use),
184 // create a new cluster with that operation and move on.
185 llvm::MapVector<StringRef, Cluster> nearest_clusters;
186 for (Operation& op : llvm::make_early_inc_range(*block)) {
187 auto device = GetDevice(&op);
188 if (device.empty()) continue;
189
190 // If no cluster of same device has been formed yet, create a new cluster
191 // with op alone.
192 auto it = nearest_clusters.find(device);
193 if (it == nearest_clusters.end()) {
194 nearest_clusters[device] = Cluster{{&op}, device};
195 continue;
196 }
197
198 // Check if it is legal to merge op into nearest cluster of same device.
199 // If positive, update cluster and move on to next operation.
200 Cluster& nearest_cluster = it->second;
201 if (CanMergeIntoCluster(nearest_cluster, &op)) {
202 nearest_cluster.ops.emplace_back(&op);
203 continue;
204 }
205
206 // If nearest cluster of same device can not absorb `op`, then that
207 // cluster needs to be finalized by building a `tf_device.launch` op with
208 // a region that contains all operations in clusters.
209 BuildLaunchForCluster(nearest_cluster, &builder);
210
211 // Create a new cluster to hold op alone and update nearest_clusters.
212 nearest_clusters[device] = Cluster{{&op}, device};
213 }
214
215 // At the end, there might be left-over found clusters that need to be
216 // built.
217 for (auto& device_cluster : nearest_clusters)
218 BuildLaunchForCluster(device_cluster.second, &builder);
219 }
220
// Pass entry point: forms clusters in every block of the function, both
// directly in the function body and inside nested `tf_executor.island` ops.
void ClusterFormationPass::runOnOperation() {
  auto func = getOperation();
  // External (declaration-only) functions have no body to transform.
  if (func.isExternal()) return;
  OpBuilder builder(func.getContext());

  // Operates on individual blocks independently of if they are directly in the
  // function body or if they are nested in individual `tf_executor.island`.
  for (Block& block : func.getBody()) BuildClusters(&block, builder);
  func.walk([&](tf_executor::IslandOp island) {
    BuildClusters(&island.GetBody(), builder);
  });
}
233
234 } // namespace
235
CreateClusterFormationPass()236 std::unique_ptr<OperationPass<func::FuncOp>> CreateClusterFormationPass() {
237 return std::make_unique<ClusterFormationPass>();
238 }
239
240 } // namespace TFDevice
241 } // namespace mlir
242