/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass hoists replicate invariant ops, or ops that yield the same
// result(s) regardless of replication, out of their respective replicate.
19 #include <memory>
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Casting.h"
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/Value.h"  // from @llvm-project
25 #include "mlir/IR/Visitors.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
28 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
32 
33 namespace mlir {
34 namespace TFDevice {
35 
36 namespace {
37 
// Name of the attribute ops use to record their assigned device.
constexpr char kDeviceAttr[] = "device";
39 
40 struct ReplicateInvariantOpHoistingPass
41     : public TF::ReplicateInvariantOpHoistingPassBase<
42           ReplicateInvariantOpHoistingPass> {
43   void runOnOperation() override;
44 };
45 
MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op,int num_replicas,Block * replicate_block,TF::ShapeOp shape_op)46 void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
47                           Block* replicate_block, TF::ShapeOp shape_op) {
48   Value input = shape_op.input();
49   // If ShapeOp operand is replicate tensor block argument, replace with the
50   // associated first replica operand.
51   if (auto block_arg = input.dyn_cast<BlockArgument>()) {
52     if (block_arg.getOwner() != replicate_block) return;
53 
54     shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument(
55         block_arg, /*replica=*/0));
56 
57     return;
58   }
59 
60   Operation* input_def = input.getDefiningOp();
61 
62   // If ShapeOp operand is a ReadVariableOp result where the ReadVariableOp
63   // operand is a replicate resource block argument, replace ShapeOp with
64   // VariableShapeOp and use the associated first replica operand as its
65   // operand.
66   auto read_var_op = llvm::dyn_cast<TF::ReadVariableOp>(input_def);
67   if (!read_var_op) return;
68 
69   // TODO(lyandy): Check if resource (first replica or replicate block arg)
70   // shape has not changed in replicate prior to read. Currently after both
71   // ResourceOpLiftingPass and TPURewritePass, there should not be any updates
72   // to resources prior to their respective ReadVariableOp.
73   if (auto block_arg = read_var_op.resource().dyn_cast<BlockArgument>()) {
74     if (block_arg.getOwner() != replicate_block) return;
75 
76     OpBuilder builder(shape_op);
77     auto new_shape_op = builder.create<TF::VariableShapeOp>(
78         shape_op.getLoc(), shape_op.getType(),
79         replicate_op.GetReplicaOperandForBlockArgument(block_arg,
80                                                        /*replica=*/0));
81     shape_op.replaceAllUsesWith(new_shape_op.getOperation());
82     shape_op.erase();
83   }
84 }
85 
86 // Check if op uses a device from a list of virtual devices.
UsesVirtualDevice(const Optional<DictionaryAttr> & virtual_devices,Operation * operation)87 bool UsesVirtualDevice(const Optional<DictionaryAttr>& virtual_devices,
88                        Operation* operation) {
89   if (!virtual_devices.has_value()) return false;
90 
91   auto result = operation->walk([&](Operation* op) {
92     StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
93     if (!op_device) return WalkResult::advance();
94 
95     if (virtual_devices.getValue().get(op_device.getValue()))
96       return WalkResult::interrupt();
97     return WalkResult::advance();
98   });
99   return result.wasInterrupted();
100 }
101 
102 // Checks if op and inner op operands are all replicate invariant.
IsOpReplicateInvariant(Region * replicate_region,Operation * op)103 bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
104   auto ancestor_of_replicate = [&](Region* region) {
105     return region && region->isProperAncestor(replicate_region);
106   };
107 
108   for (Value operand : op->getOperands())
109     if (!ancestor_of_replicate(operand.getParentRegion())) return false;
110 
111   bool has_replicate_operands = false;
112   visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
113     if (!ancestor_of_replicate(operand->get().getParentRegion()))
114       has_replicate_operands = true;
115   });
116 
117   return !has_replicate_operands;
118 }
119 
120 // Hoists replicate invariant ops out of associated `tf_device.replicate` op.
121 // Ops to be hoisted are determined by if all of their operands are replicate
122 // invariant. Shape ops are rewritten to be invariant when possible, prior to
123 // hoisting ops.
HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op)124 void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) {
125   const int num_replicas = replicate_op.n();
126   Block* replicate_block = &replicate_op.GetBody();
127 
128   replicate_op.walk([&](TF::ShapeOp shape_op) {
129     MakeShapeOpInvariant(replicate_op, num_replicas, replicate_block, shape_op);
130   });
131 
132   Region* replicate_region = &replicate_op.body();
133   Optional<DictionaryAttr> virtual_device_list = replicate_op.devices();
134   for (Operation& inner_op :
135        llvm::make_early_inc_range(replicate_op.GetBody())) {
136     if (llvm::isa<tf_device::ReturnOp>(inner_op)) continue;
137     // Skip hoisting if the inner op device attribute is a virtual device
138     // defined by tf_device.replicate.
139     if (UsesVirtualDevice(virtual_device_list, &inner_op)) continue;
140 
141     if (IsOpReplicateInvariant(replicate_region, &inner_op))
142       inner_op.moveBefore(replicate_op);
143   }
144 }
145 
runOnOperation()146 void ReplicateInvariantOpHoistingPass::runOnOperation() {
147   getOperation().walk(
148       [](tf_device::ReplicateOp op) { HoistReplicateInvariantOps(op); });
149 }
150 }  // anonymous namespace
151 
152 std::unique_ptr<OperationPass<func::FuncOp>>
CreateReplicateInvariantOpHoistingPass()153 CreateReplicateInvariantOpHoistingPass() {
154   return std::make_unique<ReplicateInvariantOpHoistingPass>();
155 }
156 
157 }  // namespace TFDevice
158 }  // namespace mlir
159