/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {
namespace {

// Pass that co-locates resource ops that use composite device resources
// (packed tensors) with the underlying physical TPU device.
struct TPUColocateCompositeResourceOps
    : public TF::TPUColocateCompositeResourceOpsPassBase<
          TPUColocateCompositeResourceOps> {
  void runOnOperation() override;
};

// Wraps single op in `tf_device.launch` for explicit device assignment.
void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
                    llvm::StringRef device) {
  builder->setInsertionPoint(op);
  auto launch = builder->create<tf_device::LaunchOp>(
      loc, builder->getStringAttr(device), op->getResultTypes());
  launch.body().push_back(new Block);
  op->replaceAllUsesWith(launch);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(loc, op->getResults());

  // Move op inside cluster.
  op->moveBefore(launch.GetBody().getTerminator());
}

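// Returns the `tf.AssignVariableOp` and `tf.ReadVariableOp` users of the
// replicate op's packed (composite) block arguments, looking through
// pass-through `tf.Identity` ops.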
llvm::SmallVector<Operation*, 4> GetResourceOpsUsingCompositeArgsInReplicate(
    tf_device::ReplicateOp replicate) {
  llvm::SmallVector<Operation*, 4> resource_users;
  const auto add_resource_op_to_list = [&resource_users](Operation* op) {
    if (!llvm::isa<TF::AssignVariableOp, TF::ReadVariableOp>(op)) return;

    resource_users.emplace_back(op);
  };

  llvm::SmallVector<Operation*, 4> resource_users_to_visit;
  for (auto composite_arguments : replicate.GetPackedBlockArguments()) {
    for (auto resource_user : composite_arguments.getUsers())
      resource_users_to_visit.emplace_back(resource_user);
  }

  while (!resource_users_to_visit.empty()) {
    llvm::SmallVector<Operation*, 4> new_resource_users;

    for (auto resource_user : resource_users_to_visit) {
      add_resource_op_to_list(resource_user);

      // Account for pass-through identity ops.
      if (auto pass_through_identity =
              llvm::dyn_cast<TF::IdentityOp>(resource_user)) {
        for (auto identity_user : pass_through_identity.output().getUsers()) {
          new_resource_users.emplace_back(identity_user);
        }
      }
    }
    resource_users_to_visit.swap(new_resource_users);
  }

  return resource_users;
}

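// Wraps each resource op that uses the replicate op's packed resources in a
// `tf_device.launch` assigned to the device alias for logical core 0, so it is
// co-located with the TPU computation on that core. No-op if the replicate op
// has no device assignment for logical core 0.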
void ColocateCompositeResourceOpsInReplicate(
    tf_device::ReplicateOp replicate_op, OpBuilder* builder) {
  auto devices = replicate_op.devices();
  if (!devices) return;
  if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0)))
    return;

  const auto composite_resource_users =
      GetResourceOpsUsingCompositeArgsInReplicate(replicate_op);
  for (auto resource_user : composite_resource_users) {
    WrapOpInLaunch(builder, resource_user->getLoc(), resource_user,
                   tensorflow::GetDeviceAliasForLogicalCore(0));
  }
}

void TPUColocateCompositeResourceOps::runOnOperation() {
  // Find all the executes first, since we will mutate the nodes around each
  // execute in the same tf_device.replicate op.
  llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
  getOperation().walk([&](tf_device::LaunchOp op) {
    if (op.WrapsSingleOp() &&
        llvm::isa<TF::TPUExecuteOp, TF::TPUExecuteAndUpdateVariablesOp>(
            op.GetBody().front()))
      execute_launches.push_back(op);
  });

  OpBuilder builder(&getContext());
  for (auto execute_launch : execute_launches) {
    auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
    if (!replicate) continue;

    ColocateCompositeResourceOpsInReplicate(replicate, &builder);
  }
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUColocateCompositeResourceOps() {
  return std::make_unique<TPUColocateCompositeResourceOps>();
}

}  // namespace TFTPU
}  // namespace mlir