/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass forms one `tf_executor.island` per replica from a single island
// that wraps a `tf_device.replicate` op.
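//
// As a standalone transformation it can be exercised in isolation, e.g.
// (assuming the registered pass flag is `-tf-replicate-to-island`):
//   tf-opt -tf-replicate-to-island input.mlir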

#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFDevice {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kReplicaIdAttr[] = "_xla_replica_id";
constexpr char kDeviceOrdinalAttr[] = "device_ordinal";
constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0";

struct ReplicateToIslandPass
    : public TF::ReplicateToIslandPassBase<ReplicateToIslandPass> {
  void runOnOperation() override;
};

// Returns whether op requires `_xla_replica_id` attribute.
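// Such ops are later tagged with the replica id, e.g. (illustrative IR only):
//   "tf.EnqueueTPUEmbeddingSparseTensorBatch"(...) {..., _xla_replica_id = 1 : i64}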
bool RequiresReplicaIDAttribute(Operation* op) {
  return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
                   TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp>(op);
}

// Collects the TPU device ordinal for outside compilation communication ops.
// This currently assumes outside compilation only uses the
// `TPU_REPLICATED_CORE_0` aliased device for the device computation.
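// For example (device names are illustrative), with
//   devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/.../device:TPU:0",
//                                       "/job:worker/.../device:TPU:1"]}
// and replica_id = 1, the parsed device ordinal would be 1.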
llvm::Optional<int64_t> GetDeviceOrdinal(
    const llvm::Optional<DictionaryAttr>& devices, Location loc,
    unsigned replica_id) {
  int64_t device_ordinal = 0;
  if (devices.has_value()) {
    if (auto tpu_replica_0 = devices.getValue().get(kTPUCore0)) {
      llvm::StringRef tpu_device = tpu_replica_0.cast<ArrayAttr>()[replica_id]
                                       .cast<StringAttr>()
                                       .getValue();
      if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString(
              loc, tpu_device, &device_ordinal))) {
        return llvm::Optional<int64_t>(device_ordinal);
      }
    }
  }
  return llvm::None;
}

// Updates replica-variant ops in a region based on the replica id
// `replica_id`.
// TODO(b/157624749): Replace this with better abstraction to differentiate ops
// for different replicas. Some ops, such as the XlaHostCompute op or TPU
// embedding ops, require the replica id to be added as an op attribute to be
// used during execution. Handle such ops separately and add an integer
// attribute that represents the replica id.
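// For illustration (hypothetical IR), inside the replica with device ordinal 1
// the placeholder
//   %ordinal = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor<i64>
// would be replaced by
//   %ordinal = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>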
LogicalResult UpdateRegionReplicateVariantOps(
    OpBuilder& builder, Location loc, Region& region, int replica_id,
    const llvm::Optional<DictionaryAttr>& devices) {
  llvm::Optional<int64_t> device_ordinal =
      GetDeviceOrdinal(devices, loc, replica_id);

  auto result = region.walk([&](Operation* op) -> WalkResult {
    if (RequiresReplicaIDAttribute(op)) {
      op->setAttr(kReplicaIdAttr, builder.getI64IntegerAttr(replica_id));
      return WalkResult::advance();
    }

    if (isa<TF::_TPUDeviceOrdinalPlaceholderOp>(op)) {
      if (!device_ordinal.has_value())
        return op->emitOpError()
               << "requires device ordinal from device " << kTPUCore0
               << " to be present in 'tf_device.replicate' op";

      OpBuilder builder(op);
      auto const_op = builder.create<TF::ConstOp>(
          op->getLoc(), DenseIntElementsAttr::get(
                            RankedTensorType::get({}, builder.getI64Type()),
                            {device_ordinal.getValue()}));
      op->replaceAllUsesWith(const_op);
      op->erase();
      return WalkResult::advance();
    }

    if (!devices.has_value()) return WalkResult::advance();

    // Map aliased devices to explicit devices based on replica.
    if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
      if (auto device_by_replica = devices.getValue().get(launch.device()))
        launch->setAttr(
            kDeviceAttr,
            device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());

    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}

// Creates islands per replica from the `tf_device.replicate` region. If the
// device of a `tf_device.launch` op is an aliased device of the
// `tf_device.replicate`, it is remapped to the explicit device of the
// associated replica island.
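// For example, with replicated inputs `[%arg0, %arg1] as %ri`, the cloned body
// of replica 0 maps %ri to %arg0 and the cloned body of replica 1 maps %ri to
// %arg1 (see the full example before CreateIslandsFromReplicate below).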
LogicalResult ExpandReplicateIntoReplicas(
    const Dialect* tf_dialect, OpBuilder& builder,
    tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
    int num_replicas, llvm::SmallVectorImpl<tf_executor::IslandOp>& replicas) {
  replicas.reserve(num_replicas);
  auto devices = replicate_op.devices();

  // Collect result types and operands.
  Operation& terminator = replicate_op.GetBody().back();
  llvm::SmallVector<Type, 8> output_types(terminator.getOperandTypes());
  auto control_type = tf_executor::ControlType::get(island_op.getContext());
  llvm::SmallVector<Value, 8> replica_inputs(island_op.controlInputs());

  // Replace replicate terminator with YieldOp.
  builder.setInsertionPoint(&terminator);
  builder.create<tf_executor::YieldOp>(terminator.getLoc(),
                                       terminator.getOperands());
  terminator.erase();

  builder.setInsertionPoint(island_op);
  BlockAndValueMapping mapping;
  for (int i : llvm::seq<int>(0, num_replicas)) {
    // Create new island for replica.
    auto replica = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), output_types, control_type, replica_inputs);

    // Map block arg to replica arg.
    mapping.clear();
    for (auto& block_arg : replicate_op.GetBody().getArguments())
      mapping.map(block_arg,
                  replicate_op.GetReplicaOperandForBlockArgument(block_arg, i));

    // Copy over replicate region into replica island.
    replicate_op.body().cloneInto(&replica.body(), mapping);

    if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(),
                                               replica.body(),
                                               /*replica_id=*/i, devices)))
      return failure();

    replicas.push_back(replica);
  }

  return success();
}

// Creates islands per replica from the `tf_device.replicate` region and remaps
// the replicate results to the new island outputs. A single island is created
// to forward control dependencies if there is a control dependency output from
// the replicate island. For `tf_device.launch` ops, devices are remapped from
// aliased devices to explicit devices.
//
// For example, the following:
//
// %0:2 = tf_executor.island(%control) {
//   %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
//              {n = 2 : i32,
//               devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
//                          DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
//     %a = "tf_device.launch"() ({
//       %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %2 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
//     %b = "tf_device.launch"() ({
//       %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
//       tf_device.return %3 : tensor<i1>
//     }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
//     tf_device.return %a, %b : tensor<i1>, tensor<i1>
//   }
//   tf_executor.yield %1#0 : tensor<i1>
// }
//
// gets lowered to:
//
// %0:3 = tf_executor.island(%control) {
//   %a0 = "tf_device.launch"() ({
//     %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %1 : tensor<i1>
//   }) {device = "/DEVICE:0"} : () -> tensor<i1>
//   %b0 = "tf_device.launch"() ({
//     %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %2 : tensor<i1>
//   }) {device = "/DEVICE:2"} : () -> tensor<i1>
//   tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
// }
// %3:3 = tf_executor.island(%control) {
//   %a1 = "tf_device.launch"() ({
//     %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %4 : tensor<i1>
//   }) {device = "/DEVICE:1"} : () -> tensor<i1>
//   %b1 = "tf_device.launch"() ({
//     %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
//     tf_device.return %5 : tensor<i1>
//   }) {device = "/DEVICE:3"} : () -> tensor<i1>
//   tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
// }
LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect,
                                         tf_executor::GraphOp graph_op,
                                         tf_executor::IslandOp island_op,
                                         tf_device::ReplicateOp replicate_op) {
  OpBuilder builder(island_op);
  const int num_replicas = replicate_op.n();

  // Create islands per replica.
  llvm::SmallVector<tf_executor::IslandOp, 8> replicas;
  if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, island_op,
                                         replicate_op, num_replicas, replicas)))
    return failure();

  // Collect all replica results.
  llvm::SmallVector<Value, 8> replicas_outputs(replicate_op.getNumResults(),
                                               nullptr);
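  // Result `j` of replica island `r` is stored at index `num_replicas * j + r`,
  // matching the ordering of the original replicate results.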
  for (auto replica_and_idx : llvm::enumerate(replicas))
    for (auto replica_result_and_idx :
         llvm::enumerate(replica_and_idx.value().outputs()))
      replicas_outputs[num_replicas * replica_result_and_idx.index() +
                       replica_and_idx.index()] =
          replica_result_and_idx.value();

  // Remap replicate results to per-replica results.
  for (auto result : llvm::zip(island_op.outputs(), replicas_outputs))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  // Add a sink island to pin all replicas as a control dependency if there was
  // originally a control dependency leading from the replicate island.
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> island_operands;
    for (auto& replica : replicas) island_operands.push_back(replica.control());

    builder.setInsertionPoint(island_op);
    auto island_sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        tf_executor::ControlType::get(island_op.getContext()), island_operands);
    island_sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&island_sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(island_sink.control());
  }

  // Replicas with no uses should be pinned to a graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> unused_replica_controls;
  for (auto& replica : replicas)
    if (replica.use_empty())
      unused_replica_controls.push_back(replica.control());

  if (!unused_replica_controls.empty()) {
    tf_executor::FetchOp fetch = graph_op.GetFetch();
    auto fetches = llvm::to_vector<8>(fetch.getOperands());
    fetches.append(unused_replica_controls.begin(),
                   unused_replica_controls.end());
    builder.setInsertionPoint(fetch);
    builder.create<tf_executor::FetchOp>(fetch.getLoc(), fetches);
    fetch.erase();
  }

  island_op.erase();
  return success();
}

void ReplicateToIslandPass::runOnOperation() {
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }

  // Find islands with a single `tf_device.replicate` and create individual
  // islands per replica of the replicate.
  llvm::SmallVector<tf_executor::IslandOp, 4> replicate_op_islands;
  getOperation().walk([&](tf_executor::GraphOp graph_op) {
    for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
      if (!island_op.WrapsSingleOp()) continue;

      if (isa<tf_device::ReplicateOp>(&island_op.GetBody().front()))
        replicate_op_islands.push_back(island_op);
    }
  });

  for (tf_executor::IslandOp island_op : replicate_op_islands) {
    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
    auto replicate_op =
        cast<tf_device::ReplicateOp>(island_op.GetBody().front());
    if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
                                          replicate_op)))
      return signalPassFailure();
  }
}
}  // anonymous namespace

std::unique_ptr<OperationPass<func::FuncOp>> CreateReplicateToIslandPass() {
  return std::make_unique<ReplicateToIslandPass>();
}

}  // namespace TFDevice
}  // namespace mlir