xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass clusters the TensorFlow ops by host. The program generated by this
17 // pass will have one function per host where all operations in the same
18 // function are placed on the same host. Each result of the per-host function
19 // will have a "tf.device" attribute which specifies the device assignment of
20 // the result.
21 //
22 // The pass currently assumes that there is no circular dependency among the
23 // per-host functions. For example, if there exists an operation placed on
24 // host_A that consumes the result of an operation placed on host_B, then there
// does not exist any operation placed on host_B that consumes any result of any
26 // operation placed on host_A.
27 
28 #include "mlir/IR/Builders.h"
29 #include "mlir/Pass/Pass.h"
30 #include "absl/strings/str_cat.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
33 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 namespace mlir {
40 namespace TF {
41 namespace {
42 
// Shorthands for TensorFlow's device-name parsing utilities.
using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName;

// Attribute set on each generated per-host function, recording its host.
constexpr const char *kHostAttr = "host";
// Attribute on an operation that records its device assignment.
constexpr const char *kDeviceAttr = "device";
// Argument/result attribute that records a value's device assignment.
constexpr const char *kTFDeviceAttr = "tf.device";
// TODO(donglin): Handle the case where the address of localhost is different
// from /job:localhost/replica:0/task:0.
constexpr const char *kLocalhost = "/job:localhost/replica:0/task:0";
// Error emitted when a value would have to be transferred across hosts,
// which this pass does not support.
constexpr const char *kErrorMessage =
    "The operation that uses the operand is on a different host than the "
    "operation that defines the op. This pass does not support cross-host data "
    "transfer yet";
56 
// The host address is identified by the job/replica/task in the device name.
GetHost(llvm::StringRef device)58 std::string GetHost(llvm::StringRef device) {
59   ParsedName parsed_name;
60   DeviceNameUtils::ParseFullName(device.str(), &parsed_name);
61   std::string result = DeviceNameUtils::ParsedNameToString(
62       DeviceNameUtils::AddressSpace(parsed_name));
63   return result.empty() ? kLocalhost : result;
64 }
65 
GetHost(Operation * op)66 std::string GetHost(Operation *op) {
67   std::string device = "";
68   if (StringAttr attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
69     device = attr.getValue().str();
70   }
71   return GetHost(device);
72 }
73 
74 // The device is considered to be on the localhost iff one of the following is
75 // true:
76 // 1) None of the job/replica/task is specified in the device name.
77 // 2) The job/replica/task in the device name are explicitly specified as
78 //    /job:localhost/replica:0/task:0.
IsOnLocalHost(llvm::StringRef device)79 bool IsOnLocalHost(llvm::StringRef device) {
80   std::string host = GetHost(device);
81   return host == kLocalhost;
82 }
83 
// This structure contains the metadata of the per-host function. All operations
// in this function should be on the same host.
struct FunctionMetadata {
  // The original function name before partition.
  llvm::StringRef original_name;
  // The insertion point of partition functions.
  Block::iterator insertion_point;
  // The partitioned function name.
  llvm::StringRef partition_name;
  // The input values of the function.
  llvm::SmallVector<Value, 4> inputs;
  // The result values of the function.
  llvm::SmallVector<Value, 4> results;
  // The devices of the input values. It should have the same size as inputs.
  llvm::SmallVector<std::string, 4> input_devices;
  // The devices of the result values. It should have the same size as results.
  llvm::SmallVector<std::string, 4> result_devices;
  // The operations to be included in the body of the function.
  llvm::SmallVector<Operation *, 4> ops;

  // The per-host FuncOp, populated once the function is materialized.
  func::FuncOp partition_op;
};
106 
// Returns a map that maps the host address to the metadata of the function
// for that remote host. The metadata of the function specifies the input
// values, result values, result devices and the operations to be included in
// the function body. Returns llvm::None (after emitting an op error) if any
// value would have to be transferred between hosts, which is unsupported.
llvm::Optional<llvm::StringMap<FunctionMetadata>> GetFunctionMetadatas(
    func::FuncOp func_op) {
  llvm::StringMap<FunctionMetadata> metadatas;
  WalkResult result = func_op.getBody().walk([&](Operation *op) {
    // Every op is bucketed by the host it is placed on.
    std::string op_host = GetHost(op);
    FunctionMetadata &func_metadata = metadatas[op_host];
    func_metadata.original_name = func_op.getName();
    // Per-host functions are inserted right after the original function.
    func_metadata.insertion_point = ++Block::iterator(func_op);
    func_metadata.ops.push_back(op);

    for (Value value : op->getOperands()) {
      std::string value_device = "";

      // If the value is defined as an argument of the func_op, adds it to
      // the argument list of the function that uses this op.
      if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
        if (StringAttr attr = func_op.getArgAttrOfType<StringAttr>(
                block_arg.getArgNumber(), kTFDeviceAttr)) {
          value_device = attr.getValue().str();
        }

        // A block argument must live on the same host as its user; crossing
        // hosts would require a data transfer.
        if (GetHost(value_device) != op_host) {
          op->emitOpError() << kErrorMessage;
          return WalkResult::interrupt();
        }

        // Record each distinct argument (and its device) only once.
        if (llvm::find(func_metadata.inputs, value) ==
            func_metadata.inputs.end()) {
          func_metadata.inputs.push_back(value);
          func_metadata.input_devices.push_back(value_device);
        }
        continue;
      }

      // Otherwise the value is produced by another op; find that op's host.
      Operation *defining_op = value.getDefiningOp();
      std::string defining_op_host = GetHost(defining_op);
      FunctionMetadata &defining_func_metadata = metadatas[defining_op_host];

      if (StringAttr attr =
              defining_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
        value_device = attr.getValue().str();
      }

      // If the value is used as an operand of the terminator op, adds it to
      // the result list of function that defines this op.
      if (op->hasTrait<OpTrait::IsTerminator>()) {
        // Record each distinct result (and its device) only once.
        if (llvm::find(defining_func_metadata.results, value) ==
            defining_func_metadata.results.end()) {
          defining_func_metadata.results.push_back(value);
          defining_func_metadata.result_devices.push_back(value_device);
        }
        continue;
      }

      // Non-terminator uses must stay on the host that defined the value.
      if (defining_op_host != op_host) {
        op->emitOpError() << kErrorMessage;
        return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return llvm::None;

  return metadatas;
}
177 
// Creates functions in the given module using the given FunctionMetadatas.
// Each remote host gets one public function whose arguments/results carry
// "tf.device" attributes and whose body is a clone of that host's ops.
void CreateFunctions(ModuleOp module_op,
                     llvm::StringMap<FunctionMetadata> &metadatas) {
  MLIRContext *context = module_op.getContext();
  SymbolTable symbol_table(module_op);
  for (auto &iter : metadatas) {
    llvm::StringRef host = iter.first();
    FunctionMetadata &metadata = iter.second;

    // Do not create any new function for the operations on the localhost.
    if (IsOnLocalHost(host)) continue;

    // Derive the function signature from the recorded inputs/results.
    llvm::SmallVector<mlir::Type, 4> input_types;
    llvm::SmallVector<mlir::Type, 4> result_types;
    for (Value input : metadata.inputs) {
      input_types.push_back(input.getType());
    }
    for (Value result : metadata.results) {
      result_types.push_back(result.getType());
    }

    // Replaces ':' and '/' with '_' in the host name and uses the resulting
    // string as the function name.
    std::string func_name =
        absl::StrCat(iter.second.original_name.str(), ":", host.str());
    std::replace(func_name.begin(), func_name.end(), ':', '_');
    std::replace(func_name.begin(), func_name.end(), '/', '_');

    FunctionType func_type =
        FunctionType::get(context, input_types, result_types);
    Location loc = metadata.ops.front()->getLoc();
    func::FuncOp func_op = func::FuncOp::create(loc, func_name, func_type);
    // Sets the device attribute for every input and every result of the
    // function.
    for (int i : llvm::seq<int>(0, metadata.input_devices.size())) {
      func_op.setArgAttr(i, kTFDeviceAttr,
                         StringAttr::get(context, metadata.input_devices[i]));
    }
    for (int i : llvm::seq<int>(0, metadata.result_devices.size())) {
      func_op.setResultAttr(
          i, kTFDeviceAttr,
          StringAttr::get(context, metadata.result_devices[i]));
    }

    // Tag the function with its host and make it externally visible.
    func_op->setAttr(kHostAttr, StringAttr::get(context, host));
    func_op.setPublic();
    Block *block = func_op.addEntryBlock();

    // Clones and moves the operations into the function's body. And the cloned
    // operation should use the arguments of the newly created func_op as
    // appropriate.
    OpBuilder builder(block, block->end());
    BlockAndValueMapping mapping;
    for (int i : llvm::seq<int>(0, metadata.inputs.size())) {
      Value original_value = metadata.inputs[i];
      Value new_value = func_op.getArgument(i);
      mapping.map(original_value, new_value);
    }
    for (Operation *op : metadata.ops) {
      builder.clone(*op, mapping);
    }
    // Creates the ReturnOp so that the per-host function returns the
    // correct values of the cloned operations.
    llvm::SmallVector<Value, 4> results_after_mapping;
    for (Value result : metadata.results) {
      results_after_mapping.push_back(mapping.lookupOrDefault(result));
    }
    builder.create<func::ReturnOp>(loc, results_after_mapping);
    symbol_table.insert(func_op, metadata.insertion_point++);
    // Record the actual name. The symbol table might rename the FuncOp if there
    // is name collision.
    metadata.partition_name = func_op.getName();
  }
}
252 
// Creates a tf_device.remote_run call for every remote function. And replaces
// usages of the results of the original operations with the results of the
// tf_device.remote_run calls.
// NOTE(review): `context` is not referenced in the body; it appears to be
// kept only for interface symmetry — confirm before removing.
void CreateRemoteRunCalls(MLIRContext *context,
                          const llvm::StringMap<FunctionMetadata> &metadatas) {
  // Maps original values to remote_run results so later calls can consume
  // the results of earlier ones.
  BlockAndValueMapping mapping;
  for (auto &iter : metadatas) {
    llvm::StringRef host = iter.first();
    const FunctionMetadata &metadata = iter.second;

    // Do not create tf_device.remote_run call for the operations already placed
    // on the localhost.
    if (IsOnLocalHost(host)) continue;

    // Creates the tf_device.remote_run operation.
    OpBuilder builder(metadata.ops.back());
    llvm::SmallVector<Type, 4> result_types;
    for (Value result : metadata.results) {
      result_types.push_back(result.getType());
    }
    Location loc = metadata.ops.front()->getLoc();
    llvm::SmallVector<Value, 4> inputs_after_mapping;
    for (Value input : metadata.inputs) {
      inputs_after_mapping.push_back(mapping.lookupOrDefault(input));
    }

    tf_device::RemoteRunOp remote_run_op =
        builder.create<tf_device::RemoteRunOp>(loc, result_types, host,
                                               metadata.partition_name,
                                               inputs_after_mapping);
    // Clones the tf_device.remote_run operation to replace its callee args with
    // the results of the other tf_device.remote_run operations using the
    // `mapping` as appropriate.
    Operation *cloned_remote_run_op =
        builder.clone(*remote_run_op.getOperation(), mapping);
    remote_run_op.erase();

    // Replaces usages of the results of the original operations with the
    // results of the tf_device.remote_run operations.
    for (int i : llvm::seq<int>(0, metadata.results.size())) {
      Value original_value = metadata.results[i];
      Value new_value = cloned_remote_run_op->getResult(i);
      original_value.replaceAllUsesWith(new_value);
      mapping.map(original_value, new_value);
    }
  }
}
300 
class ClusterTFOpsByHostPass
    : public ClusterTFOpsByHostPassBase<ClusterTFOpsByHostPass> {
  // Partitions every function in the module by host: computes per-host
  // metadata, materializes per-host functions, rewires remote results through
  // tf_device.remote_run calls, then erases the original remote ops.
  // Signals pass failure if cross-host data transfer is detected.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp module_op = getOperation();
    // Snapshot the functions first: CreateFunctions inserts new FuncOps into
    // the module, which must not perturb this iteration.
    SmallVector<func::FuncOp, 4> original_func;
    for (auto func_op : module_op.getOps<func::FuncOp>()) {
      original_func.push_back(func_op);
    }
    for (auto func_op : original_func) {
      llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
          GetFunctionMetadatas(func_op);
      if (!metadatas) {
        // GetFunctionMetadatas already emitted the diagnostic.
        signalPassFailure();
        return;
      }

      CreateFunctions(module_op, *metadatas);
      CreateRemoteRunCalls(context, *metadatas);

      // Erases the original operations which have been cloned in the remote
      // functions.
      for (auto &iter : *metadatas) {
        llvm::StringRef host = iter.first();
        FunctionMetadata &metadata = iter.second;
        // Do not erase operations placed on the localhost.
        if (IsOnLocalHost(host)) continue;

        // Erase in reverse collection order so an op's remaining users are
        // removed before the op itself.
        for (int i = metadata.ops.size() - 1; i >= 0; i--) {
          metadata.ops[i]->erase();
        }
      }
    }
  }
};
336 
337 }  // namespace
338 
CreateClusterTFOpsByHostPass()339 std::unique_ptr<OperationPass<mlir::ModuleOp>> CreateClusterTFOpsByHostPass() {
340   return std::make_unique<ClusterTFOpsByHostPass>();
341 }
342 
343 }  // namespace TF
344 }  // namespace mlir
345