/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass hoists replicate invariant ops, or ops that yield the same
// result(s) regardless of replication, out of their respective replicate.
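//
// For example (an illustrative sketch; SSA value names and types are
// hypothetical), a tf.Shape of a replicated argument:
//
//   tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
//     %2 = "tf.Shape"(%ri) : (tensor<*xi32>) -> tensor<?xi32>
//     tf_device.return
//   }
//
// is first rewritten to use the first replica operand, making it replicate
// invariant, and is then hoisted above the tf_device.replicate:
//
//   %2 = "tf.Shape"(%0) : (tensor<*xi32>) -> tensor<?xi32>
//   tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
//     tf_device.return
//   }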

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFDevice {

namespace {

constexpr char kDeviceAttr[] = "device";

struct ReplicateInvariantOpHoistingPass
    : public TF::ReplicateInvariantOpHoistingPassBase<
          ReplicateInvariantOpHoistingPass> {
  void runOnOperation() override;
};

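// Makes a tf.Shape op replicate invariant if possible. If the ShapeOp input is
// a replicated tensor block argument, it is rewritten to take the first
// replica's operand instead; if the input is a ReadVariableOp of a replicated
// resource block argument, the ShapeOp is replaced with a VariableShapeOp on
// the first replica's resource. For example (an illustrative sketch; names and
// types are hypothetical):
//
//   %2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xf32>
//   %3 = "tf.Shape"(%2) : (tensor<*xf32>) -> tensor<?xi32>
//
// becomes, with %0 the resource operand of the first replica:
//
//   %2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf.resource>) -> tensor<*xf32>
//   %3 = "tf.VariableShape"(%0) : (tensor<*x!tf.resource>) -> tensor<?xi32>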
void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
                          Block* replicate_block, TF::ShapeOp shape_op) {
  Value input = shape_op.input();
  // If the ShapeOp operand is a replicated tensor block argument, replace it
  // with the corresponding operand of the first replica.
  if (auto block_arg = input.dyn_cast<BlockArgument>()) {
    if (block_arg.getOwner() != replicate_block) return;

    shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument(
        block_arg, /*replica=*/0));

    return;
  }

  Operation* input_def = input.getDefiningOp();

  // If the ShapeOp operand is the result of a ReadVariableOp whose resource
  // operand is a replicated resource block argument, replace the ShapeOp with
  // a VariableShapeOp that takes the corresponding first replica operand.
  auto read_var_op = llvm::dyn_cast<TF::ReadVariableOp>(input_def);
  if (!read_var_op) return;

  // TODO(lyandy): Check if resource (first replica or replicate block arg)
  // shape has not changed in replicate prior to read. Currently after both
  // ResourceOpLiftingPass and TPURewritePass, there should not be any updates
  // to resources prior to their respective ReadVariableOp.
  if (auto block_arg = read_var_op.resource().dyn_cast<BlockArgument>()) {
    if (block_arg.getOwner() != replicate_block) return;

    OpBuilder builder(shape_op);
    auto new_shape_op = builder.create<TF::VariableShapeOp>(
        shape_op.getLoc(), shape_op.getType(),
        replicate_op.GetReplicaOperandForBlockArgument(block_arg,
                                                       /*replica=*/0));
    shape_op.replaceAllUsesWith(new_shape_op.getOperation());
    shape_op.erase();
  }
}

// Checks whether `operation`, or any op nested within it, uses a device from
// the replicate op's list of virtual devices.
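// For example (hypothetical names), an op carrying
// `device = "TPU_REPLICATED_CORE_0"` is considered to use a virtual device if
// that alias is a key in the replicate's `devices` attribute; such ops only
// have meaning inside the replicate and must not be hoisted.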
bool UsesVirtualDevice(const Optional<DictionaryAttr>& virtual_devices,
                       Operation* operation) {
  if (!virtual_devices.has_value()) return false;

  auto result = operation->walk([&](Operation* op) {
    StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
    if (!op_device) return WalkResult::advance();

    if (virtual_devices.getValue().get(op_device.getValue()))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}

// Checks if the operands of an op, and of any ops nested within it, are all
// replicate invariant (i.e. defined outside the replicate region).
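// For example, an op with no operands (such as a tf.Const), or one whose
// operands are all defined above the tf_device.replicate, is replicate
// invariant, while an op that uses a replicated block argument is not.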
bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
  // A value is replicate invariant if it is defined in a region that properly
  // contains the replicate region (i.e. outside the replicate).
  auto ancestor_of_replicate = [&](Region* region) {
    return region && region->isProperAncestor(replicate_region);
  };

  for (Value operand : op->getOperands())
    if (!ancestor_of_replicate(operand.getParentRegion())) return false;

  // Also check values captured from above by ops in any regions nested under
  // `op`.
  bool has_replicate_operands = false;
  visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
    if (!ancestor_of_replicate(operand->get().getParentRegion()))
      has_replicate_operands = true;
  });

  return !has_replicate_operands;
}

// Hoists replicate invariant ops out of the associated `tf_device.replicate`
// op. An op is hoisted if all of its operands are replicate invariant. Shape
// ops are rewritten to be invariant when possible, prior to hoisting.
void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) {
  const int num_replicas = replicate_op.n();
  Block* replicate_block = &replicate_op.GetBody();

  replicate_op.walk([&](TF::ShapeOp shape_op) {
    MakeShapeOpInvariant(replicate_op, num_replicas, replicate_block, shape_op);
  });

  Region* replicate_region = &replicate_op.body();
  Optional<DictionaryAttr> virtual_device_list = replicate_op.devices();
  for (Operation& inner_op :
       llvm::make_early_inc_range(replicate_op.GetBody())) {
    if (llvm::isa<tf_device::ReturnOp>(inner_op)) continue;
    // Skip hoisting if the inner op's device attribute is a virtual device
    // defined by tf_device.replicate.
    if (UsesVirtualDevice(virtual_device_list, &inner_op)) continue;

    if (IsOpReplicateInvariant(replicate_region, &inner_op))
      inner_op.moveBefore(replicate_op);
  }
}

void ReplicateInvariantOpHoistingPass::runOnOperation() {
  getOperation().walk(
      [](tf_device::ReplicateOp op) { HoistReplicateInvariantOps(op); });
}
}  // anonymous namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateReplicateInvariantOpHoistingPass() {
  return std::make_unique<ReplicateInvariantOpHoistingPass>();
}
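
// Example usage (a minimal sketch; the pass manager setup is assumed and not
// part of this file):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::TFDevice::CreateReplicateInvariantOpHoistingPass());
//   if (mlir::failed(pm.run(module))) { /* handle pass failure */ }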

}  // namespace TFDevice
}  // namespace mlir