xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass inserts corert.transfer ops to make sure every argument of any op is
17 // on the same device as the op itself.
18 
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/Types.h"  // from @llvm-project
25 #include "mlir/Pass/PassManager.h"  // from @llvm-project
26 #include "mlir/Transforms/Passes.h"  // from @llvm-project
27 #include "tensorflow/core/util/device_name_utils.h"
28 #include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
29 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
30 #include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
31 #include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
32 
33 namespace tensorflow {
34 
35 namespace {
36 
37 using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
38 
// Op attribute that holds the device an op is placed on.
constexpr char kDeviceAttr[] = "device";
// Function-argument attribute that holds the device an argument lives on.
constexpr char kTFRTDeviceAttr[] = "tfrt.device";
// Device assumed for ops/arguments that carry no device annotation.
// TODO(b/175480458): Do not assign default device once every op in the TF
// dialect has the device attribute.
constexpr char kDefaultDevice[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";
45 
46 // This method canonicalizes the device name so that we can use string
47 // comparison to see if two devices are the same. It does the following
48 // transformations:
49 // 1) Set device ID to 0 if device ID is not already specified.
50 // 2) Change the device type to uppercase string.
CanonicalizeDeviceName(const std::string & device)51 static std::string CanonicalizeDeviceName(const std::string &device) {
52   if (device.empty()) return kDefaultDevice;
53 
54   DeviceNameUtils::ParsedName parsed_name;
55   if (!device.empty() && device.at(0) == '/') {
56     DeviceNameUtils::ParseFullName(device, &parsed_name);
57   } else {
58     DeviceNameUtils::ParseFullName("/device:" + device, &parsed_name);
59   }
60 
61   if (!parsed_name.has_id) {
62     parsed_name.has_id = true;
63     parsed_name.id = 0;
64   }
65 
66   if (parsed_name.type == "cpu")
67     parsed_name.type = "CPU";
68   else if (parsed_name.type == "gpu")
69     parsed_name.type = "GPU";
70   else if (parsed_name.type == "tpu")
71     parsed_name.type = "TPU";
72   return DeviceNameUtils::ParsedNameToString(parsed_name);
73 }
74 
75 // Return the device of the given operation.
GetDevice(Operation * op)76 static std::string GetDevice(Operation *op) {
77   std::string device = "";
78   if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
79     device = device_attr.getValue().str();
80   } else if (auto execute_op = llvm::dyn_cast<tfrt::corert::ExecuteOp>(op)) {
81     SmallVector<std::pair<StringRef, Attribute>, 4> attrs;
82     execute_op.getOpAttrs(&attrs);
83     for (std::pair<StringRef, Attribute> entry : attrs) {
84       if (entry.first == kDeviceAttr && entry.second.isa<StringAttr>()) {
85         device = entry.second.cast<StringAttr>().getValue().str();
86         break;
87       }
88     }
89   }
90 
91   return CanonicalizeDeviceName(device);
92 }
93 
94 // Return the device of the given value.
GetDevice(mlir::Value value,func::FuncOp parent_func_op)95 static std::string GetDevice(mlir::Value value, func::FuncOp parent_func_op) {
96   std::string device = "";
97   if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
98     if (StringAttr device_attr = parent_func_op.getArgAttrOfType<StringAttr>(
99             block_arg.getArgNumber(), kTFRTDeviceAttr)) {
100       device = device_attr.getValue().str();
101     }
102   } else {
103     device = GetDevice(value.getDefiningOp());
104   }
105 
106   return CanonicalizeDeviceName(device);
107 }
108 
109 struct CrossDeviceTransferPass
110     : public PassWrapper<CrossDeviceTransferPass, OperationPass<func::FuncOp>> {
111   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CrossDeviceTransferPass)
112 
113   void runOnOperation() override;
114 
getArgumenttensorflow::__anon4aa4a9b60111::CrossDeviceTransferPass115   llvm::StringRef getArgument() const final {
116     return "tfrt-cross-device-transfer";
117   }
118 
getDescriptiontensorflow::__anon4aa4a9b60111::CrossDeviceTransferPass119   llvm::StringRef getDescription() const final {
120     return "This pass inserts corert.transfer op to make sure any argument of "
121            "any op is on the same device of the op itself.";
122   }
123 };
124 
runOnOperation()125 void CrossDeviceTransferPass::runOnOperation() {
126   func::FuncOp func_op = getOperation();
127   llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
128       transferred_value_by_value_and_device;
129 
130   func_op.getBody().walk([&](Operation *op) {
131     if (op->hasTrait<OpTrait::IsTerminator>()) return WalkResult::advance();
132     // Do not transfer the argument of corert.transfer op.
133     if (llvm::isa<tfrt::corert::TransferOp>(op)) return WalkResult::advance();
134 
135     OpBuilder builder(op);
136     std::string dst_device = GetDevice(op);
137     mlir::Type tensor_type_type =
138         builder.getType<::tfrt::compiler::TensorTypeType>();
139     mlir::Type device_type = builder.getType<::tfrt::compiler::DeviceType>();
140 
141     for (mlir::Value arg : op->getOperands()) {
142       // Do not transfer non-TensorHandle values.
143       if (!arg.getType().isa<tfrt::corert::TensorHandleType>()) continue;
144 
145       // Do not transfer the result of corert.transfer op.
146       if (OpResult op_result = arg.dyn_cast<OpResult>()) {
147         Operation *defining_op = arg.getDefiningOp();
148         if (llvm::isa<tfrt::corert::TransferOp>(defining_op)) continue;
149       }
150 
151       std::string src_device = GetDevice(arg, func_op);
152 
153       if (DeviceNameUtils::LocalName(src_device) ==
154           DeviceNameUtils::LocalName(dst_device))
155         continue;
156 
157       // Re-use the value already transferred to the given device.
158       llvm::StringMap<mlir::Value> &transferred_value_by_device =
159           transferred_value_by_value_and_device[arg];
160       auto iter = transferred_value_by_device.find(dst_device);
161       if (iter != transferred_value_by_device.end()) {
162         op->replaceUsesOfWith(arg, iter->second);
163         continue;
164       }
165 
166       mlir::Value chain_in = func_op.getArgument(0);
167       auto get_device_op = builder.create<tfrt::compiler::GetDeviceOp>(
168           op->getLoc(), device_type, chain_in, dst_device);
169       auto get_tensor_type_op =
170           builder.create<tfrt::corert::GetDstTensorTypeOp>(
171               op->getLoc(), tensor_type_type, arg, get_device_op.getResult());
172       auto transfer_op = builder.create<tfrt::corert::TransferOp>(
173           op->getLoc(), arg.getType(), arg, get_device_op.getResult(),
174           get_tensor_type_op.getResult());
175       mlir::Value new_arg = transfer_op.getResult();
176       transferred_value_by_device[dst_device] = new_arg;
177       op->replaceUsesOfWith(arg, new_arg);
178     }
179     return WalkResult::advance();
180   });
181 }
182 
183 }  // namespace
184 
CreateCrossDeviceTransferPass()185 std::unique_ptr<OperationPass<func::FuncOp>> CreateCrossDeviceTransferPass() {
186   return std::make_unique<CrossDeviceTransferPass>();
187 }
188 
189 static PassRegistration<CrossDeviceTransferPass> pass;
190 
191 }  // namespace tensorflow
192