1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass clusters the TensorFlow ops by host. The program generated by this
17 // pass will have one function per host where all operations in the same
18 // function are placed on the same host. Each result of the per-host function
19 // will have a "tf.device" attribute which specifies the device assignment of
20 // the result.
21 //
22 // The pass currently assumes that there is no circular dependency among the
23 // per-host functions. For example, if there exists an operation placed on
24 // host_A that consumes the result of an operation placed on host_B, then there
// does not exist any operation placed on host_B that consumes any result of any
26 // operation placed on host_A.
27
28 #include "mlir/IR/Builders.h"
29 #include "mlir/Pass/Pass.h"
30 #include "absl/strings/str_cat.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38
39 namespace mlir {
40 namespace TF {
41 namespace {
42
// Shorthand aliases for TensorFlow's device-name parsing utilities.
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;

// Attribute name under which each created per-host function records its host.
constexpr const char *kHostAttr = "host";
// Op attribute that carries the op's assigned device name.
constexpr const char *kDeviceAttr = "device";
// Arg/result attribute that carries the device of a function argument/result.
constexpr const char *kTFDeviceAttr = "tf.device";
// TODO(donglin): Handle the case where the address of localhost is different
// from /job:localhost/replica:0/task:0.
constexpr const char *kLocalhost = "/job:localhost/replica:0/task:0";
// Diagnostic emitted when a value would have to cross hosts between two
// operations, which this pass does not support.
constexpr const char *kErrorMessage =
    "The operation that uses the operand is on a different host than the "
    "operation that defines the op. This pass does not support cross-host data "
    "transfer yet";
56
57 // The host address is identified by the job/replicate/task in the device name.
GetHost(llvm::StringRef device)58 std::string GetHost(llvm::StringRef device) {
59 ParsedName parsed_name;
60 DeviceNameUtils::ParseFullName(device.str(), &parsed_name);
61 std::string result = DeviceNameUtils::ParsedNameToString(
62 DeviceNameUtils::AddressSpace(parsed_name));
63 return result.empty() ? kLocalhost : result;
64 }
65
GetHost(Operation * op)66 std::string GetHost(Operation *op) {
67 std::string device = "";
68 if (StringAttr attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
69 device = attr.getValue().str();
70 }
71 return GetHost(device);
72 }
73
74 // The device is considered to be on the localhost iff one of the following is
75 // true:
76 // 1) None of the job/replica/task is specified in the device name.
77 // 2) The job/replica/task in the device name are explicitly specified as
78 // /job:localhost/replica:0/task:0.
IsOnLocalHost(llvm::StringRef device)79 bool IsOnLocalHost(llvm::StringRef device) {
80 std::string host = GetHost(device);
81 return host == kLocalhost;
82 }
83
// This structure contains the metadata of the per-host function. All operations
// in this function should be on the same host.
struct FunctionMetadata {
  // The original function name before partition.
  llvm::StringRef original_name;
  // The insertion point of partition functions (right after the original
  // function in the enclosing block).
  Block::iterator insertion_point;
  // The partitioned function name. Set by CreateFunctions after symbol-table
  // insertion, since the symbol table may rename the function on collision.
  llvm::StringRef partition_name;
  // The input values of the function.
  llvm::SmallVector<Value, 4> inputs;
  // The result values of the function.
  llvm::SmallVector<Value, 4> results;
  // The devices of the input values. It should have the same size as inputs.
  llvm::SmallVector<std::string, 4> input_devices;
  // The devices of the result values. It should have the same size as results.
  llvm::SmallVector<std::string, 4> result_devices;
  // The operations to be included in the body of the function, in walk order.
  llvm::SmallVector<Operation *, 4> ops;

  // NOTE(review): never assigned within this file — presumably populated by
  // external code or dead; verify before relying on it.
  func::FuncOp partition_op;
};
106
// Returns a map that maps the host address to the metadata of the function
// for that remote host. The metadata of the function specifies the input
// values, result values, result devices and the operations to be included in
// the function body. Returns llvm::None (after emitting an op error) when a
// value would have to cross hosts between two non-terminator operations,
// which this pass does not support.
llvm::Optional<llvm::StringMap<FunctionMetadata>> GetFunctionMetadatas(
    func::FuncOp func_op) {
  llvm::StringMap<FunctionMetadata> metadatas;
  WalkResult result = func_op.getBody().walk([&](Operation *op) {
    // Bucket each operation by the host its "device" attribute resolves to.
    std::string op_host = GetHost(op);
    FunctionMetadata &func_metadata = metadatas[op_host];
    func_metadata.original_name = func_op.getName();
    // Partitioned functions will be inserted right after the original one.
    func_metadata.insertion_point = ++Block::iterator(func_op);
    func_metadata.ops.push_back(op);

    for (Value value : op->getOperands()) {
      std::string value_device = "";

      // If the value is defined as an argument of the func_op, adds it to
      // the argument list of the function that uses this op.
      if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
        if (StringAttr attr = func_op.getArgAttrOfType<StringAttr>(
                block_arg.getArgNumber(), kTFDeviceAttr)) {
          value_device = attr.getValue().str();
        }

        // An argument must live on the same host as its consumer; cross-host
        // data transfer is not supported.
        if (GetHost(value_device) != op_host) {
          op->emitOpError() << kErrorMessage;
          return WalkResult::interrupt();
        }

        // Deduplicate: each argument (and its device) is recorded at most
        // once per per-host function.
        if (llvm::find(func_metadata.inputs, value) ==
            func_metadata.inputs.end()) {
          func_metadata.inputs.push_back(value);
          func_metadata.input_devices.push_back(value_device);
        }
        continue;
      }

      Operation *defining_op = value.getDefiningOp();
      std::string defining_op_host = GetHost(defining_op);
      FunctionMetadata &defining_func_metadata = metadatas[defining_op_host];

      if (StringAttr attr =
              defining_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
        value_device = attr.getValue().str();
      }

      // If the value is used as an operand of the terminator op, adds it to
      // the result list of function that defines this op.
      if (op->hasTrait<OpTrait::IsTerminator>()) {
        // Deduplicated like inputs: one result entry per distinct value.
        if (llvm::find(defining_func_metadata.results, value) ==
            defining_func_metadata.results.end()) {
          defining_func_metadata.results.push_back(value);
          defining_func_metadata.result_devices.push_back(value_device);
        }
        continue;
      }

      // A non-terminator use of a value produced on a different host would
      // require cross-host data transfer, which is unsupported.
      if (defining_op_host != op_host) {
        op->emitOpError() << kErrorMessage;
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return llvm::None;

  return metadatas;
}
177
// Creates functions in the given module using the given FunctionMetadatas:
// one function per remote host, each tagged with a "host" attribute and with
// "tf.device" attributes on its arguments and results. No function is created
// for the localhost entry. Records the actual (possibly uniqued) symbol name
// of each created function in metadata.partition_name.
void CreateFunctions(ModuleOp module_op,
                     llvm::StringMap<FunctionMetadata> &metadatas) {
  MLIRContext *context = module_op.getContext();
  SymbolTable symbol_table(module_op);
  for (auto &iter : metadatas) {
    llvm::StringRef host = iter.first();
    FunctionMetadata &metadata = iter.second;

    // Do not create any new function for the operations on the localhost.
    if (IsOnLocalHost(host)) continue;

    llvm::SmallVector<mlir::Type, 4> input_types;
    llvm::SmallVector<mlir::Type, 4> result_types;
    for (Value input : metadata.inputs) {
      input_types.push_back(input.getType());
    }
    for (Value result : metadata.results) {
      result_types.push_back(result.getType());
    }

    // Builds the function name as "<original name>:<host>", then replaces
    // ':' and '/' with '_' in the result to form a legal symbol name.
    std::string func_name =
        absl::StrCat(iter.second.original_name.str(), ":", host.str());
    std::replace(func_name.begin(), func_name.end(), ':', '_');
    std::replace(func_name.begin(), func_name.end(), '/', '_');

    FunctionType func_type =
        FunctionType::get(context, input_types, result_types);
    Location loc = metadata.ops.front()->getLoc();
    func::FuncOp func_op = func::FuncOp::create(loc, func_name, func_type);
    // Sets the "tf.device" attribute for every input and every result of the
    // function.
    for (int i : llvm::seq<int>(0, metadata.input_devices.size())) {
      func_op.setArgAttr(i, kTFDeviceAttr,
                         StringAttr::get(context, metadata.input_devices[i]));
    }
    for (int i : llvm::seq<int>(0, metadata.result_devices.size())) {
      func_op.setResultAttr(
          i, kTFDeviceAttr,
          StringAttr::get(context, metadata.result_devices[i]));
    }

    func_op->setAttr(kHostAttr, StringAttr::get(context, host));
    func_op.setPublic();
    Block *block = func_op.addEntryBlock();

    // Clones and moves the operations into the function's body. And the cloned
    // operation should use the arguments of the newly created func_op as
    // appropriate.
    OpBuilder builder(block, block->end());
    BlockAndValueMapping mapping;
    // Map each captured input to the corresponding entry-block argument so
    // the clones below reference the new function's arguments.
    for (int i : llvm::seq<int>(0, metadata.inputs.size())) {
      Value original_value = metadata.inputs[i];
      Value new_value = func_op.getArgument(i);
      mapping.map(original_value, new_value);
    }
    for (Operation *op : metadata.ops) {
      builder.clone(*op, mapping);
    }
    // Creates the ReturnOp so that the per-host function returns the
    // correct values of the cloned operations.
    llvm::SmallVector<Value, 4> results_after_mapping;
    for (Value result : metadata.results) {
      results_after_mapping.push_back(mapping.lookupOrDefault(result));
    }
    builder.create<func::ReturnOp>(loc, results_after_mapping);
    // Post-increment keeps subsequent insertions (for other hosts of the same
    // original function) in creation order after this one.
    symbol_table.insert(func_op, metadata.insertion_point++);
    // Record the actual name. The symbol table might rename the FuncOp if there
    // is name collision.
    metadata.partition_name = func_op.getName();
  }
}
252
// Creates a tf_device.remote_run call for every remote function. And replaces
// usages of the results of the original operations with the results of the
// tf_device.remote_run calls. Relies on the pass-level assumption (see file
// header) that there is no circular dependency among the per-host functions,
// so a single traversal can feed earlier calls' results into later calls via
// `mapping`.
void CreateRemoteRunCalls(MLIRContext *context,
                          const llvm::StringMap<FunctionMetadata> &metadatas) {
  // Maps each original result value to the remote_run result replacing it;
  // shared across hosts so later calls pick up earlier replacements.
  BlockAndValueMapping mapping;
  for (auto &iter : metadatas) {
    llvm::StringRef host = iter.first();
    const FunctionMetadata &metadata = iter.second;

    // Do not create tf_device.remote_run call for the operations already placed
    // on the localhost.
    if (IsOnLocalHost(host)) continue;

    // Creates the tf_device.remote_run operation, inserted where the last
    // original op of this host lives so all its operands dominate it.
    OpBuilder builder(metadata.ops.back());
    llvm::SmallVector<Type, 4> result_types;
    for (Value result : metadata.results) {
      result_types.push_back(result.getType());
    }
    Location loc = metadata.ops.front()->getLoc();
    llvm::SmallVector<Value, 4> inputs_after_mapping;
    for (Value input : metadata.inputs) {
      inputs_after_mapping.push_back(mapping.lookupOrDefault(input));
    }

    tf_device::RemoteRunOp remote_run_op =
        builder.create<tf_device::RemoteRunOp>(loc, result_types, host,
                                               metadata.partition_name,
                                               inputs_after_mapping);
    // Clones the tf_device.remote_run operation to replace its callee args with
    // the results of the other tf_device.remote_run operations using the
    // `mapping` as appropriate.
    Operation *cloned_remote_run_op =
        builder.clone(*remote_run_op.getOperation(), mapping);
    remote_run_op.erase();

    // Replaces usages of the results of the original operations with the
    // results of the tf_device.remote_run operations.
    for (int i : llvm::seq<int>(0, metadata.results.size())) {
      Value original_value = metadata.results[i];
      Value new_value = cloned_remote_run_op->getResult(i);
      original_value.replaceAllUsesWith(new_value);
      mapping.map(original_value, new_value);
    }
  }
}
300
// Module pass that partitions each function by host: collects per-host
// metadata, outlines remote-host operations into new per-host functions,
// replaces them with tf_device.remote_run calls, and erases the originals.
class ClusterTFOpsByHostPass
    : public ClusterTFOpsByHostPassBase<ClusterTFOpsByHostPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp module_op = getOperation();
    // Snapshot the existing functions first: CreateFunctions adds new
    // functions to the module, and those must not be processed again.
    SmallVector<func::FuncOp, 4> original_func;
    for (auto func_op : module_op.getOps<func::FuncOp>()) {
      original_func.push_back(func_op);
    }
    for (auto func_op : original_func) {
      llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
          GetFunctionMetadatas(func_op);
      if (!metadatas) {
        // GetFunctionMetadatas has already emitted a diagnostic.
        signalPassFailure();
        return;
      }

      CreateFunctions(module_op, *metadatas);
      CreateRemoteRunCalls(context, *metadatas);

      // Erases the original operations which have been cloned in the remote
      // functions.
      for (auto &iter : *metadatas) {
        llvm::StringRef host = iter.first();
        FunctionMetadata &metadata = iter.second;
        // Do not erase operations placed on the localhost.
        if (IsOnLocalHost(host)) continue;

        // Erase in reverse walk order so later (using) ops go before the ops
        // that define their operands.
        for (int i = metadata.ops.size() - 1; i >= 0; i--) {
          metadata.ops[i]->erase();
        }
      }
    }
  }
};
336
337 } // namespace
338
CreateClusterTFOpsByHostPass()339 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateClusterTFOpsByHostPass() {
340 return std::make_unique<ClusterTFOpsByHostPass>();
341 }
342
343 } // namespace TF
344 } // namespace mlir
345