xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation forms clusters from instructions in the same island
// that are assigned to the same devices. Clusters are represented as regions.
// Note that side-effecting ops are not correctly handled yet.
19 
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Block.h"  // from @llvm-project
25 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/Operation.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace mlir {
37 namespace TFDevice {
38 
39 namespace {
40 
// Pass that groups same-device operations within a block into
// `tf_device.launch` clusters. The actual work happens in runOnOperation(),
// defined below.
struct ClusterFormationPass
    : public TF::ClusterFormationPassBase<ClusterFormationPass> {
  void runOnOperation() override;
};
45 
// Cluster structure captures all the operations that are assigned to same
// device and can form a legal strict cluster.
// Ops must follow same ordering in their parent block. We rely on this
// assumption to perform analysis.
struct Cluster {
  // Operations belonging to the cluster, kept in the same relative order as
  // in their parent block (they may be non-contiguous there).
  llvm::SmallVector<Operation*, 4> ops;
  // Device name shared by all ops in the cluster (value of their "device"
  // attribute).
  StringRef device;
};
54 
GetDevice(Operation * op)55 StringRef GetDevice(Operation* op) {
56   auto device_attr = op->getAttrOfType<StringAttr>("device");
57   return device_attr ? device_attr.getValue() : "";
58 }
59 
// An op can be merged into cluster if all of its operands are one of the
// following:
//  1) A block argument
//  2) A value produced by other islands
//  3) Defined before the cluster
//  4) Defined by an operation in the cluster
// TODO(ycao): This is not optimal as it doesn't consider the situation of
// defining_op's operands all meet the requirements above. In that case, the
// defining_op can be moved and to_merge op would be legal to absorb.
// TODO(ycao): Take op side-effects into consideration since they can not be
// re-ordered but forming clusters of non-continuous ops is effectively
// re-ordering them.
bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
  return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
    // Block arguments.
    if (operand.isa<BlockArgument>()) return true;

    // Non-null here: `operand` is not a block argument, so it must have a
    // defining operation.
    Operation* defining_op = operand.getDefiningOp();

    // Operand produced by other islands.
    if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;

    // Defining op is before the cluster.
    if (defining_op->isBeforeInBlock(c.ops.front())) return true;

    // Defining op is between first and last operation in cluster. Note that
    // cluster may contain operations that are non-continuous in their original
    // block, thus we also need to check defining_op is also assigned to
    // cluster's device to be sure. This is a faster check than linearly
    // searching through all ops in cluster.
    if (defining_op->isBeforeInBlock(c.ops.back()->getNextNode()) &&
        GetDevice(defining_op) == c.device)
      return true;

    // Other cases, operand is generated after or outside the cluster, this
    // means it is illegal to merge operation.
    return false;
  });
}
99 
ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,tf_device::LaunchOp launch_op)100 void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
101                                 tf_device::LaunchOp launch_op) {
102   Region* launch_op_region = &launch_op.body();
103   for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
104     Value from = std::get<0>(p);
105     // TODO(jingpu): move this to RegionUtils.h in MLIR core.
106     for (auto& use : llvm::make_early_inc_range(from.getUses())) {
107       if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
108         continue;
109       use.set(std::get<1>(p));
110     }
111   }
112 }
113 
114 // Get all escaped live-out values of a region.
GetLiveOuts(Region * region,llvm::SmallVectorImpl<Value> * live_outs)115 void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
116   live_outs->clear();
117 
118   for (Operation& op : region->front()) {
119     for (Value v : op.getResults()) {
120       // A value is live-out if any of its users are not inside value producer's
121       // region.
122       bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
123         return !region->isAncestor(user->getParentRegion());
124       });
125 
126       if (is_live_out) live_outs->emplace_back(v);
127     }
128   }
129 }
130 
131 // Build a `tf_device.launch` op with a region that contains all the operations
132 // in given cluster. Then all ops in cluster are replaced by `tf_device.launch`.
BuildLaunchForCluster(const Cluster & c,OpBuilder * builder)133 void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
134   // Set insertion point to right after all operations in cluster.
135   builder->setInsertionPoint(c.ops.back()->getNextNode());
136 
137   // Create a stand-alone region to hold all instructions in the cluster.
138   Region region;
139   region.push_back(new Block);
140 
141   // Move all operations in cluster to newly created region, stripping their
142   // "device" attribute since launch op already carries device information.
143   Block* block = &region.front();
144   for (Operation* op : c.ops) {
145     op->moveBefore(block, block->end());
146     op->removeAttr(builder->getStringAttr("device"));
147   }
148 
149   // Get all escaped live-out values of region, they are used later to determine
150   // return values and types of launch op.
151   llvm::SmallVector<Value, 4> live_outs;
152   GetLiveOuts(&region, &live_outs);
153 
154   // Build a `tf_device.return` op at end of region, with all live-out values
155   // as operand.
156   OpBuilder return_builder(builder->getContext());
157   return_builder.setInsertionPointToEnd(block);
158   return_builder.create<tf_device::ReturnOp>(return_builder.getUnknownLoc(),
159                                              live_outs);
160 
161   llvm::SmallVector<Type, 4> live_out_types;
162   live_out_types.reserve(live_outs.size());
163   for (Value v : live_outs) {
164     live_out_types.emplace_back(v.getType());
165   }
166 
167   tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
168       builder->getUnknownLoc(), builder->getStringAttr(c.device),
169       live_out_types);
170 
171   // Attach the region to launch_op.
172   launch_op.body().takeBody(region);
173 
174   // Replace any external uses of live-out values with return values of launch
175   // op. So live-out values no longer escape the region.
176   ReplaceLiveOutExternalUses(live_outs, launch_op);
177 }
178 
// Forms per-device clusters within `block` and wraps each finalized cluster
// into a `tf_device.launch` op. `builder` is taken by value because each
// BuildLaunchForCluster call resets its insertion point.
void BuildClusters(Block* block, OpBuilder builder) {
  // Iteratively find clusters of different devices within an island.
  // Whenever we see an operation that is assigned to an accelerator device
  // (ie. device != ""), we try to merge it into the last cluster of same
  // device. If that is infeasible (say because of violating def-before-use),
  // create a new cluster with that operation and move on.
  llvm::MapVector<StringRef, Cluster> nearest_clusters;
  // early_inc_range: finalizing a cluster moves ops out of `block` while we
  // are iterating over it.
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    auto device = GetDevice(&op);
    // Ops without a device assignment never participate in clustering.
    if (device.empty()) continue;

    // If no cluster of same device has been formed yet, create a new cluster
    // with op alone.
    auto it = nearest_clusters.find(device);
    if (it == nearest_clusters.end()) {
      nearest_clusters[device] = Cluster{{&op}, device};
      continue;
    }

    // Check if it is legal to merge op into nearest cluster of same device.
    // If positive, update cluster and move on to next operation.
    Cluster& nearest_cluster = it->second;
    if (CanMergeIntoCluster(nearest_cluster, &op)) {
      nearest_cluster.ops.emplace_back(&op);
      continue;
    }

    // If nearest cluster of same device can not absorb `op`, then that
    // cluster needs to be finalized by building a `tf_device.launch` op with
    // a region that contains all operations in clusters.
    BuildLaunchForCluster(nearest_cluster, &builder);

    // Create a new cluster to hold op alone and update nearest_clusters.
    nearest_clusters[device] = Cluster{{&op}, device};
  }

  // At the end, there might be left-over found clusters that need to be
  // built.
  for (auto& device_cluster : nearest_clusters)
    BuildLaunchForCluster(device_cluster.second, &builder);
}
220 
runOnOperation()221 void ClusterFormationPass::runOnOperation() {
222   auto func = getOperation();
223   if (func.isExternal()) return;
224   OpBuilder builder(func.getContext());
225 
226   // Operates on individual blocks independently of if they are directly in the
227   // function body or if they are nested in individual `tf_executor.island`.
228   for (Block& block : func.getBody()) BuildClusters(&block, builder);
229   func.walk([&](tf_executor::IslandOp island) {
230     BuildClusters(&island.GetBody(), builder);
231   });
232 }
233 
234 }  // namespace
235 
// Factory exposed via passes.h; returns a fresh cluster formation pass
// instance.
std::unique_ptr<OperationPass<func::FuncOp>> CreateClusterFormationPass() {
  return std::make_unique<ClusterFormationPass>();
}
239 
240 }  // namespace TFDevice
241 }  // namespace mlir
242