1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass forms `tf_executor.island` per replica from a single
17 // `tf_device.replicate` island.
18
19 #include <memory>
20 #include <utility>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/Optional.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Block.h" // from @llvm-project
32 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
33 #include "mlir/IR/Builders.h" // from @llvm-project
34 #include "mlir/IR/Diagnostics.h" // from @llvm-project
35 #include "mlir/IR/Dialect.h" // from @llvm-project
36 #include "mlir/IR/Visitors.h" // from @llvm-project
37 #include "mlir/Pass/Pass.h" // from @llvm-project
38 #include "mlir/Support/LogicalResult.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
45
46 namespace mlir {
47 namespace TFDevice {
48 namespace {
// Attribute name used to assign an explicit device to an op.
constexpr char kDeviceAttr[] = "device";
// Attribute carrying the replica id on replica-variant ops (e.g. TPU
// embedding enqueue ops).
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
// NOTE(review): not referenced elsewhere in this file — possibly kept for
// other users or historical reasons; confirm before removing.
constexpr char kDeviceOrdinalAttr[] = "device_ordinal";
// Aliased device name for the first TPU replicated core in the
// `tf_device.replicate` `devices` dictionary.
constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0";
53
// Pass that expands each `tf_executor.island` wrapping a single
// `tf_device.replicate` op into one island per replica.
struct ReplicateToIslandPass
    : public TF::ReplicateToIslandPassBase<ReplicateToIslandPass> {
  void runOnOperation() override;
};
58
59 // Returns whether op requires `_xla_replica_id` attribute.
RequiresReplicaIDAttribute(Operation * op)60 bool RequiresReplicaIDAttribute(Operation* op) {
61 return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
62 TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
63 TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp>(op);
64 }
65
66 // Collects TPU device ordinal for outside compilation communication ops. This
67 // currently assumes outside compilation only uses `TPU_REPLICATED_CORE_0`
68 // aliased device for the device computation.
GetDeviceOrdinal(const llvm::Optional<DictionaryAttr> & devices,Location loc,unsigned replica_id)69 llvm::Optional<int64_t> GetDeviceOrdinal(
70 const llvm::Optional<DictionaryAttr>& devices, Location loc,
71 unsigned replica_id) {
72 int64_t device_ordinal = 0;
73 if (devices.has_value()) {
74 if (auto tpu_replica_0 = devices.getValue().get(kTPUCore0)) {
75 llvm::StringRef tpu_device = tpu_replica_0.cast<ArrayAttr>()[replica_id]
76 .cast<StringAttr>()
77 .getValue();
78 if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString(
79 loc, tpu_device, &device_ordinal))) {
80 return llvm::Optional<int64_t>(device_ordinal);
81 }
82 }
83 }
84 return llvm::None;
85 }
86
87 // Updates replica variant ops in a region based on replica `replica_id`.
88 // TODO(b/157624749): Replace this with better abstraction to differentiate ops
89 // for different replicas. Some ops, such as XlaHostCompute op or TPU Embedding
90 // ops, require replica id to be added as an op attribute to be used during
91 // execution. Handle such ops separately and add an integer attribute that
92 // represents replica id.
UpdateRegionReplicateVariantOps(OpBuilder & builder,Location loc,Region & region,int replica_id,const llvm::Optional<DictionaryAttr> & devices)93 LogicalResult UpdateRegionReplicateVariantOps(
94 OpBuilder& builder, Location loc, Region& region, int replica_id,
95 const llvm::Optional<DictionaryAttr>& devices) {
96 llvm::Optional<int64_t> device_ordinal =
97 GetDeviceOrdinal(devices, loc, replica_id);
98
99 auto result = region.walk([&](Operation* op) -> WalkResult {
100 if (RequiresReplicaIDAttribute(op)) {
101 op->setAttr(kReplicaIdAttr, builder.getI64IntegerAttr(replica_id));
102 return WalkResult::advance();
103 }
104
105 if (isa<TF::_TPUDeviceOrdinalPlaceholderOp>(op)) {
106 if (!device_ordinal.has_value())
107 return op->emitOpError()
108 << "requires device ordinal from device " << kTPUCore0
109 << " to be present in 'tf.device.replicate' op";
110
111 OpBuilder builder(op);
112 auto const_op = builder.create<TF::ConstOp>(
113 op->getLoc(), DenseIntElementsAttr::get(
114 RankedTensorType::get({}, builder.getI64Type()),
115 {device_ordinal.getValue()}));
116 op->replaceAllUsesWith(const_op);
117 op->erase();
118 return WalkResult::advance();
119 }
120
121 if (!devices.has_value()) return WalkResult::advance();
122
123 // Map aliased devices to explicit devices based on replica.
124 if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
125 if (auto device_by_replica = devices.getValue().get(launch.device()))
126 launch->setAttr(
127 kDeviceAttr,
128 device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());
129
130 return WalkResult::advance();
131 });
132
133 return failure(result.wasInterrupted());
134 }
135
// Creates islands per replica from `tf_device.replicate` region. If for a
// `tf_device.launch` op the device is an aliased device of the
// `tf_device.replicate`, the device will be remapped to an explicit device
// for the associated replica island. On success, `replicas` holds
// `num_replicas` new islands in replica order, inserted before `island_op`.
LogicalResult ExpandReplicateIntoReplicas(
    const Dialect* tf_dialect, OpBuilder& builder,
    tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
    int num_replicas, llvm::SmallVectorImpl<tf_executor::IslandOp>& replicas) {
  replicas.reserve(num_replicas);
  auto devices = replicate_op.devices();

  // Collect result types and operands.
  Operation& terminator = replicate_op.GetBody().back();
  llvm::SmallVector<Type, 8> output_types(terminator.getOperandTypes());
  auto control_type = tf_executor::ControlType::get(island_op.getContext());
  llvm::SmallVector<Value, 8> replica_inputs(island_op.controlInputs());

  // Replace replicate terminator with YieldOp, so the body cloned below is
  // already terminated the way an island region expects.
  builder.setInsertionPoint(&terminator);
  builder.create<tf_executor::YieldOp>(terminator.getLoc(),
                                       terminator.getOperands());
  terminator.erase();

  builder.setInsertionPoint(island_op);
  BlockAndValueMapping mapping;
  for (int i : llvm::seq<int>(0, num_replicas)) {
    // Create new island for replica.
    auto replica = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), output_types, control_type, replica_inputs);

    // Map block arg to replica arg. The mapping is rebuilt per replica since
    // each replica reads a different slice of the replicate operands.
    mapping.clear();
    for (auto& block_arg : replicate_op.GetBody().getArguments())
      mapping.map(block_arg,
                  replicate_op.GetReplicaOperandForBlockArgument(block_arg, i));

    // Copy over replicate region into replica island.
    replicate_op.body().cloneInto(&replica.body(), mapping);

    if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(),
                                               replica.body(),
                                               /*replica_id=*/i, devices)))
      return failure();

    replicas.push_back(replica);
  }

  return success();
}
185
186 // Creates islands per replica from `tf_device.replicate` region and remap
187 // replicate results with new island outputs. A single island is created to
188 // forward control dependencies if there is a control dependency output from the
189 // replicate island. Devices are remapped from aliased devices to explicit
190 // devices, for `tf_device.launch` ops.
191 //
192 // For example, the following:
193 //
194 // %0:2 = tf_executor.island(%control) {
195 // %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
196 // {n = 2 : i32,
197 // devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
198 // DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
199 // %a = "tf_device.launch"() ({
200 // %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
201 // tf_device.return %2 : tensor<i1>
202 // }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
203 // %b = "tf_device.launch"() ({
204 // %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
205 // tf_device.return %3 : tensor<i1>
206 // }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
207 // tf_device.return %a, %b : tensor<i1>, tensor<i1>
208 // }
209 // tf_executor.yield %1#0 : tensor<i1>
210 // }
211 //
212 // gets lowered to:
213 //
214 // %0:3 = tf_executor.island(%control) {
215 // %a0 = "tf_device.launch"() ({
216 // %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
217 // tf_device.return %1 : tensor<i1>
218 // }) {device = "/DEVICE:0"} : () -> tensor<i1>
219 // %b0 = "tf_device.launch"() ({
220 // %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
221 // tf_device.return %2 : tensor<i1>
222 // }) {device = "/DEVICE:2"} : () -> tensor<i1>
223 // tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
224 // }
225 // %3:3 = tf_executor.island(%control) {
226 // %a1 = "tf_device.launch"() ({
227 // %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
228 // tf_device.return %4 : tensor<i1>
229 // }) {device = "/DEVICE:1"} : () -> tensor<i1>
230 // %b1 = "tf_device.launch"() ({
231 // %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
232 // tf_device.return %5 : tensor<i1>
233 // }) {device = "/DEVICE:3"} : () -> tensor<i1>
234 // tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
235 // }
// Expands the replicate held by `island_op` into per-replica islands inside
// `graph_op`, rewires all users, and erases the original island. See the
// comment block above for a worked example.
LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                         tf_executor::GraphOp graph_op,
                                         tf_executor::IslandOp island_op,
                                         tf_device::ReplicateOp replicate_op) {
  OpBuilder builder(island_op);
  const int num_replicas = replicate_op.n();

  // Create islands per replica.
  llvm::SmallVector<tf_executor::IslandOp, 8> replicas;
  if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, island_op,
                                         replicate_op, num_replicas, replicas)))
    return failure();

  // Collect all replica results. Replicate results are grouped by logical
  // result: result `r` of replica `i` lands at index `num_replicas * r + i`.
  llvm::SmallVector<Value, 8> replicas_outputs(replicate_op.getNumResults(),
                                               nullptr);
  for (auto replica_and_idx : llvm::enumerate(replicas))
    for (auto replica_result_and_idx :
         llvm::enumerate(replica_and_idx.value().outputs()))
      replicas_outputs[num_replicas * replica_result_and_idx.index() +
                       replica_and_idx.index()] =
          replica_result_and_idx.value();

  // Remap replicate results to per replica result.
  for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  // Add sink island to pin all replicas as a control dependency if there is a
  // control dependency leading from the replicate originally.
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> island_operands;
    for (auto& replica : replicas) island_operands.push_back(replica.control());

    builder.setInsertionPoint(island_op);
    auto island_sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        tf_executor::ControlType::get(island_op.getContext()), island_operands);
    island_sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&island_sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(island_sink.control());
  }

  // Replicas with no uses should be pinned to a graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> unused_replica_controls;
  for (auto& replica : replicas)
    if (replica.use_empty())
      unused_replica_controls.push_back(replica.control());

  if (!unused_replica_controls.empty()) {
    // The fetch is rebuilt wholesale with the extra control operands appended,
    // then the original fetch is erased.
    tf_executor::FetchOp fetch = graph_op.GetFetch();
    auto fetches = llvm::to_vector<8>(fetch.getOperands());
    fetches.append(unused_replica_controls.begin(),
                   unused_replica_controls.end());
    builder.setInsertionPoint(fetch);
    builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
    fetch.erase();
  }

  island_op.erase();
  return success();
}
300
runOnOperation()301 void ReplicateToIslandPass::runOnOperation() {
302 const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
303 if (!tf_dialect) {
304 getOperation().emitError() << "'tf' dialect is not registered";
305 return signalPassFailure();
306 }
307
308 // Find islands with a single `tf_device.replicate` and create individual
309 // islands per replica of the replicate.
310 llvm::SmallVector<tf_executor::IslandOp, 4> replicate_op_islands;
311 getOperation().walk([&](tf_executor::GraphOp graph_op) {
312 for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
313 if (!island_op.WrapsSingleOp()) continue;
314
315 if (isa<tf_device::ReplicateOp>(&island_op.GetBody().front()))
316 replicate_op_islands.push_back(island_op);
317 }
318 });
319
320 for (tf_executor::IslandOp island_op : replicate_op_islands) {
321 auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
322 auto replicate_op =
323 cast<tf_device::ReplicateOp>(island_op.GetBody().front());
324 if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
325 replicate_op)))
326 return signalPassFailure();
327 }
328 }
329 } // anonymous namespace
330
CreateReplicateToIslandPass()331 std::unique_ptr<OperationPass<func::FuncOp>> CreateReplicateToIslandPass() {
332 return std::make_unique<ReplicateToIslandPass>();
333 }
334
335 } // namespace TFDevice
336 } // namespace mlir
337