/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#include <string>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/collectives_common.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/group_assignment.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

namespace ops_util = ::mlir::TF::collection_ops_util;
constexpr int32 kUninitializedGroupKey = 0;

// A counter that is used to generate shift base values for TF collective group
// and instance keys. Every TF collective AllReduce op in a program gets a value
// from this counter. The value increments according to the position of the
// AllReduce op in the program. Different hosts go through exactly the same MLIR
// logic and therefore iterate over AllReduce ops in the same order (even in the
// presence of control flow), so they should independently generate the same
// counter value for matching AllReduce ops across hosts.
static std::atomic<int32> tf_collective_key_base{0};

}  // namespace
}  // namespace dtensor
}  // namespace tensorflow

#ifdef PLATFORM_GOOGLE
// Use the Google internal version of EmitAllReduceForXla.
#include "collective_lowering_google.inc"
#else
namespace tensorflow {
namespace dtensor {
namespace {
constexpr char kCrossReplica[] = "CrossReplica";

mlir::LogicalResult EmitAllReduceForXla(
    mlir::MLIRContext& context, mlir::OpBuilder& builder,
    mlir::TF::DTensorAllReduceOp all_reduce,
    mlir::DenseIntElementsAttr group_assignment_attr, int32 key_base,
    mlir::Operation** final_op) {
  // For TPUs, lower to XlaAllReduce straightforwardly.
  *final_op = builder.create<mlir::TF::XlaAllReduceOp>(
      all_reduce.getLoc(), all_reduce.getResult().getType(), all_reduce.input(),
      all_reduce.group_assignment(), all_reduce.reduce_opAttr(),
      builder.getStringAttr(kCrossReplica));
  return mlir::success();
}
}  // namespace
}  // namespace dtensor
}  // namespace tensorflow
#endif

namespace tensorflow {
namespace dtensor {
namespace {
// Emit a host CollectiveReduce op for the given input.
// `group_assignment` is used to generate an array of group keys.
// `device_id` slices into that array to get the key for a device at runtime.
// `key_base` is the common part shared by all group keys.
// `device_id` is an mlir::Value that will contain the device ID at runtime.
// `host_group_size` sets host collective group size. It should match the
// number of active devices running the host collective and supplying device
// IDs, else the host collective will crash or hang.
mlir::Operation* EmitCollectiveReduce(
    mlir::OpBuilder& builder, const mlir::Location& loc, mlir::Value input,
    const std::string& reduce_op_str,
    const mlir::DenseIntElementsAttr& group_assignment, int32 key_base,
    mlir::Value device_id, int32 host_group_size,
    const mlir::StringRef device_type) {
  DCHECK_EQ(group_assignment.getType().getRank(), 2);
  auto shape = group_assignment.getType().getShape();
  const int32 num_groups = shape[0];
  const int32 group_size = shape[1];
  const int32 num_devices = num_groups * group_size;
  const mlir::TensorType input_type =
      input.getType().dyn_cast<mlir::TensorType>();

  const bool need_int32_to_int64_upcast =
      (device_type.endswith("GPU") && input_type &&
       input_type.getElementType().isInteger(32));

  if (need_int32_to_int64_upcast) {
    LOG(WARNING) << "On GPU, collective reduce of int32 is not supported. "
                    "Casting to int64 as a workaround: "
                 << mlir::debugString(loc);

    mlir::TF::CastOp cast_to_int64 = builder.create<mlir::TF::CastOp>(
        loc,
        mlir::RankedTensorType::get(input_type.getShape(),
                                    builder.getIntegerType(64)),
        input);
    input = cast_to_int64.getResult();
  }
  mlir::Value group_key_scalar;
  llvm::SmallVector<int32, 4> device_id_to_group_key(num_devices);
  device_id_to_group_key.resize(num_devices, kUninitializedGroupKey);
  // 21 bits + 11 bits allow roughly 2M all-reduces in one program and up to a
  // full DF pod.
  DCHECK_LT(key_base, 1L << 21) << "Reaching 2^21 all-reduces.";
  DCHECK_LE(num_devices, 1L << 11) << "Exceeding 2048 groups.";
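  // Illustrative example of the packing done below (hypothetical numbers, not
  // taken from any real program): with key_base = 3 and
  // group_assignment = [[0, 1], [2, 3]] (2 groups of size 2), the loop produces
  //   device_id_to_group_key = {3 << 11, 3 << 11, (3 << 11) ^ 1, (3 << 11) ^ 1},
  // i.e. devices 0 and 1 share one group key and devices 2 and 3 share another,
  // while keys from different all-reduces never collide because key_base
  // occupies the high bits.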
  for (const auto& it :
       llvm::enumerate(group_assignment.getValues<llvm::APInt>())) {
    int32 device_id = it.value().getSExtValue();
    DCHECK_LE(0, device_id);
    DCHECK_LT(device_id, num_devices);
    DCHECK_EQ(device_id_to_group_key[device_id], kUninitializedGroupKey);
    const int32 group_id = static_cast<int32>(it.index()) / group_size;
    device_id_to_group_key[device_id] = (key_base << 11) ^ group_id;
  }

  // Create a scalar group key by slicing device_id_to_group_key with
  // device_id.
  auto group_key_slice = builder.create<mlir::TF::SliceOp>(
      loc, EffectivelyScalarR1Type(builder.getIntegerType(32)),
      /*input=*/IntConst(builder, loc, device_id_to_group_key),
      /*begin=*/device_id,
      /*size=*/IntConst(builder, loc, {1}));
  auto group_key_reshape = builder.create<mlir::TF::ReshapeOp>(
      loc, /*tensor=*/group_key_slice.getResult(),
      /*shape=*/ops_util::GetR1Const({}, builder, loc));
  group_key_scalar = group_key_reshape.getResult();

  // Generate a unique instance key for this collective.
  mlir::Value instance_key_scalar =
      ops_util::CreateScalarConst(static_cast<int32>(key_base), builder, loc);

  const bool is_mean_op = reduce_op_str == kReduceOpMean;
  mlir::Value group_size_scalar =
      ops_util::CreateScalarConst(host_group_size, builder, loc);
  auto collective_reduce = builder.create<mlir::TF::CollectiveReduceV2Op>(
      loc, /*output_type=*/input.getType(), input, group_size_scalar,
      group_key_scalar, instance_key_scalar,
      /*ordering_token=*/mlir::ValueRange({}),
      /*merge_op=*/builder.getStringAttr(is_mean_op ? "Add" : reduce_op_str),
      /*final_op=*/builder.getStringAttr(is_mean_op ? "Div" : "Id"),
      /*communication_hint=*/builder.getStringAttr(""),
      /*timeout_seconds=*/builder.getF32FloatAttr(0.),
      /*max_subdivs_per_device=*/builder.getI64IntegerAttr(16));
  SetSingleLayoutOnOp(collective_reduce, Layout::Empty());
  if (need_int32_to_int64_upcast) {
    return builder.create<mlir::TF::CastOp>(
        loc,
        mlir::RankedTensorType::get(input_type.getShape(),
                                    builder.getIntegerType(32)),
        collective_reduce);
  }
  return collective_reduce;
}

mlir::LogicalResult LowerAllReduceOpImpl(
    mlir::MLIRContext& context, mlir::OpBuilder& builder,
    mlir::TF::DTensorAllReduceOp all_reduce, mlir::Value* value) {
  mlir::Location loc = all_reduce.getLoc();
  StatusOr<Layout> output_layout =
      ExtractRequiredSingleLayoutFromOp(all_reduce);
  if (!output_layout.ok()) {
    return all_reduce.emitOpError(output_layout.status().error_message());
  }
  mlir::DenseIntElementsAttr group_assignment_attr;
  if (!matchPattern(all_reduce.group_assignment(),
                    m_Constant(&group_assignment_attr)))
    return mlir::emitError(loc, "group_assignment must be a constant.");
  if (group_assignment_attr.getType().getRank() != 2)
    return mlir::emitError(loc, "group_assignment should have two dimensions.");
  int32 group_size = group_assignment_attr.getType().getShape()[1];

  // This will become more general when Topology is properly defined.
  const bool is_tpu = all_reduce.device_type().endswith("TPU");
  // Use an atomic counter to generate bases for group and instance keys.
  int32 key_base = tf_collective_key_base++;

  mlir::Operation* final_op;
  if (is_tpu) {
    if (mlir::failed(EmitAllReduceForXla(context, builder, all_reduce,
                                         group_assignment_attr, key_base,
                                         &final_op))) {
      return mlir::failure();
    }
  } else {
    // Generate a CPU/GPU collective. CPU/GPU collectives identify groups on
    // the basis of a local group key. We must generate an appropriate group
    // key based on our device ID. This is expressible as an algebraic
    // function of the device id, but we instead encode the
    // device_id->group_key mapping as an explicit map value and look up the
    // result at runtime. Note that the order we map devices to partitions is
    // not deterministic, and moreover if we have multiple distinct reduction
    // groups in one program (e.g., one reducing over all hosts and another
    // reducing over pairs of hosts), we need unique ids for each case.
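    // Hypothetical example of the lookup this sets up: on a mesh whose global
    // device ids are 8..15, min_global_device_id() is 8, so a runtime
    // device_id of 10 becomes relative_device_id 2, which then indexes the
    // device_id->group_key table built in EmitCollectiveReduce.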
    mlir::Value device_id = ops_util::ReshapeScalarToSizeType(
        builder, DeviceId(all_reduce.getResult()).ValueOrDie(), loc);
    // TODO(b/188076080): Clean up device id.
    mlir::Value start_device_id = ops_util::GetR1Const(
        {(*output_layout).mesh().min_global_device_id()}, builder, loc);
    mlir::Value relative_device_id =
        builder.create<mlir::TF::SubOp>(loc, device_id, start_device_id);

    final_op = EmitCollectiveReduce(
        builder, loc, all_reduce.input(), all_reduce.reduce_op().str(),
        group_assignment_attr, key_base, relative_device_id,
        /*host_group_size=*/group_size, all_reduce.device_type().str());
  }
  SetSingleLayoutOnOp(final_op, *output_layout);
  *value = final_op->getResult(0);
  return mlir::success();
}

template <class ReduceOpType>
mlir::LogicalResult ConvertBoolReduce(ReduceOpType reduce_op) {
  mlir::OpBuilder builder(reduce_op);
  const mlir::Location loc = reduce_op.getLoc();
  const mlir::Type output_type = reduce_op.getResult().getType();
  const mlir::Type input_type = reduce_op.getOperand(0).getType();

  // Handle bools by first casting to int32 and swapping All/Any for Min/Max.
  const mlir::TensorType& tensor_input_type =
      input_type.dyn_cast<mlir::TensorType>();
  const mlir::TensorType& tensor_output_type =
      output_type.dyn_cast<mlir::TensorType>();
  if (tensor_input_type && tensor_output_type &&
      tensor_input_type.getElementType().isInteger(1)) {
    if (reduce_op.reduce_opAttr().getValue().str() == kReduceOpAll)
      reduce_op.reduce_opAttr(builder.getStringAttr(std::string(kReduceOpMin)));
    else if (reduce_op.reduce_opAttr().getValue().str() == kReduceOpAny)
      reduce_op.reduce_opAttr(builder.getStringAttr(std::string(kReduceOpMax)));
    else
      return reduce_op.emitOpError()
             << "reduce for boolean only supports 'All' or 'Any' reduction. "
             << "Received '" << reduce_op.reduce_opAttr().getValue().str()
             << "'";
    const mlir::Type integer_input_type = mlir::RankedTensorType::get(
        tensor_input_type.getShape(), builder.getIntegerType(32));
    mlir::TF::CastOp cast_to_int32 = builder.create<mlir::TF::CastOp>(
        loc, integer_input_type, reduce_op.input());
    reduce_op.setOperand(0, cast_to_int32.y());
    const mlir::Type integer_output_type = mlir::RankedTensorType::get(
        tensor_output_type.getShape(), builder.getIntegerType(32));
    reduce_op.output().setType(integer_output_type);

    // Add cast back to boolean after reduction.
    mlir::Value result = reduce_op.output();
    builder.setInsertionPointAfter(reduce_op);
    mlir::TF::CastOp cast_to_bool =
        builder.create<mlir::TF::CastOp>(loc, output_type, result);
    StatusOr<Layout> result_layout =
        ExtractRequiredSingleLayoutFromOp(result.getDefiningOp());
    if (!result_layout.ok()) {
      return reduce_op.emitOpError(result_layout.status().error_message());
    }
    SetSingleLayoutOnOp(cast_to_bool, *result_layout);
    reduce_op.output().replaceAllUsesExcept(cast_to_bool.y(), cast_to_bool);
  }

  return mlir::success();
}

mlir::LogicalResult LowerAllReduceOp(mlir::MLIRContext& context,
                                     mlir::TF::DTensorAllReduceOp all_reduce) {
  if (mlir::failed(ConvertBoolReduce<mlir::TF::DTensorAllReduceOp>(all_reduce)))
    return mlir::failure();

  mlir::OpBuilder builder(all_reduce);
  mlir::Value result;
  if (mlir::failed(LowerAllReduceOpImpl(context, builder, all_reduce, &result)))
    return mlir::failure();

  all_reduce.replaceAllUsesWith(result);
  all_reduce.erase();
  return mlir::success();
}

mlir::LogicalResult LowerReduceScatterOp(
    mlir::TF::DTensorReduceScatterOp reduce_scatter) {
  mlir::Location loc = reduce_scatter.getLoc();

  StatusOr<Layout> output_layout =
      ExtractRequiredSingleLayoutFromOp(reduce_scatter);
  if (!output_layout.ok()) {
    return reduce_scatter.emitOpError(output_layout.status().error_message());
  }
  mlir::DenseIntElementsAttr group_assignment_attr;
  if (!matchPattern(reduce_scatter.group_assignment(),
                    m_Constant(&group_assignment_attr)))
    return reduce_scatter.emitOpError("group_assignment must be a constant.");
  if (group_assignment_attr.getType().getRank() != 2)
    return reduce_scatter.emitOpError(
        "group_assignment should have two dimensions.");

  mlir::OpBuilder builder(reduce_scatter);
  if (reduce_scatter.device_type().endswith("TPU")) {
    if (mlir::failed(ConvertBoolReduce<mlir::TF::DTensorReduceScatterOp>(
            reduce_scatter)))
      return mlir::failure();
    // For TPUs, lower to XlaReduceScatter straightforwardly.
    mlir::Operation* xla_reduce_scatter =
        builder.create<mlir::TF::XlaReduceScatterOp>(
            loc, reduce_scatter.getResult().getType(), reduce_scatter.input(),
            reduce_scatter.group_assignment(),
            reduce_scatter.scatter_dimension(), reduce_scatter.reduce_opAttr());
    SetSingleLayoutOnOp(xla_reduce_scatter, *output_layout);
    reduce_scatter.replaceAllUsesWith(xla_reduce_scatter);
  } else {
    // For non-TPU devices, decompose into DTensorAllReduce + DTensorAllScatter.
    StatusOr<Layout> input_layout =
        ExtractRequiredLayoutFromOperand(reduce_scatter.input());
    if (!input_layout.ok()) {
      // If the input layout is not defined, derive it from the output layout
      // by unsharding the scattered dimension.
      mlir::DenseIntElementsAttr scatter_attr;
      if (!matchPattern(reduce_scatter.scatter_dimension(),
                        m_Constant(&scatter_attr))) {
        return reduce_scatter.emitOpError(
            "Scatter dimension not constant integer array.");
      }
      mlir::APInt scatter_dim = *scatter_attr.begin();
      std::vector<string> input_sharding_spec =
          output_layout->sharding_spec_strs();
      input_sharding_spec[scatter_dim.getSExtValue()] = Layout::kUnshardedDim;
      input_layout =
          Layout::GetLayout(input_sharding_spec, output_layout->mesh());
    }
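    // Illustrative example (hypothetical layout names): if the output layout
    // is sharded as ["x", "y"] and scatter_dimension is 0, the derived input
    // layout becomes ["unsharded", "y"]; the all-reduce below then runs with
    // the scattered axis still replicated, and the all-scatter splits axis 0
    // across the "x" mesh dimension.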

    if (!input_layout.ok()) {
      return reduce_scatter.emitOpError(input_layout.status().error_message());
    }

    auto dtensor_allreduce = builder.create<mlir::TF::DTensorAllReduceOp>(
        reduce_scatter.getLoc(), reduce_scatter.getOperand(0).getType(),
        reduce_scatter.getOperand(0), reduce_scatter.group_assignment(),
        reduce_scatter.reduce_op(), reduce_scatter.device_type());
    SetSingleLayoutOnOp(dtensor_allreduce, *input_layout);

    mlir::Operation* dtensor_all_scatter =
        builder.create<mlir::TF::DTensorAllScatterOp>(
            reduce_scatter.getLoc(), reduce_scatter.getResult().getType(),
            dtensor_allreduce.getResult(),
            mlir::dtensor::LayoutAttr::get(builder.getContext(), *input_layout),
            mlir::dtensor::LayoutAttr::get(builder.getContext(),
                                           *output_layout));
    SetSingleLayoutOnOp(dtensor_all_scatter, *output_layout);
    reduce_scatter.replaceAllUsesWith(dtensor_all_scatter);
  }
  reduce_scatter.erase();
  return mlir::success();
}

mlir::Value CreateZeroScalar(mlir::OpBuilder& builder, mlir::Location loc,
                             mlir::RankedTensorType type) {
  const mlir::Value zero_scalar = ops_util::CreateScalarConst(0, builder, loc);
  return builder.create<mlir::TF::CastOp>(
      loc, mlir::RankedTensorType::get({}, type.getElementType()), zero_scalar);
}

// device_id is the relative device_id in a mesh (device id - mesh's 1st device
// id).
mlir::Value SelectElementsBasedOnId(
    mlir::OpBuilder& builder, mlir::Location loc, mlir::Value device_id,
    const llvm::SmallVectorImpl<int64>& candidates_flat, int64 num_devices,
    int64 output_shape_size) {
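  // Hypothetical example of what this helper computes: with num_devices = 2,
  // output_shape_size = 3, and candidates_flat = {0, 0, 0, 4, 0, 0}, the flat
  // list is viewed as a 2x3 matrix and a relative device_id of 1 selects the
  // second row, yielding the length-3 vector {4, 0, 0} at runtime.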
  // Reshape the flat list to a matrix of shape num_devices * output_shape_size.
  const mlir::Value candidates_flat_const =
      ops_util::GetR1Const(candidates_flat, builder, loc);
  const mlir::Value candidates_shape =
      ops_util::GetR1Const({num_devices, output_shape_size}, builder, loc);
  const mlir::Value candidates = builder.create<mlir::TF::ReshapeOp>(
      loc, candidates_flat_const, candidates_shape);

  // Add a zero after the only value in the 1x1 device_id tensor.
  const mlir::Value device_id_paddings = builder.create<mlir::TF::ReshapeOp>(
      loc, ops_util::GetR1Const({0, 1}, builder, loc),
      ops_util::GetR1Const({1, 2}, builder, loc));
  const mlir::Value device_id_padded = builder.create<mlir::TF::PadOp>(
      loc, candidates_shape.getType(), /*input=*/device_id,
      /*paddings=*/device_id_paddings);

  // Slice one row out of the 2D candidates matrix.
  const mlir::RankedTensorType chosen_shape_type = mlir::RankedTensorType::get(
      {1, output_shape_size}, builder.getIntegerType(32));
  const mlir::Value chosen_shape_const =
      ops_util::GetR1Const(chosen_shape_type.getShape(), builder, loc);
  const mlir::Value chosen = builder.create<mlir::TF::SliceOp>(
      loc, chosen_shape_type, /*input=*/candidates, /*begin=*/device_id_padded,
      /*size=*/chosen_shape_const);

  // Remove the leading dimension of size 1 before returning the result.
  return builder.create<mlir::TF::ReshapeOp>(
      loc, chosen, ops_util::GetR1Const({output_shape_size}, builder, loc));
}

mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) {
  const Layout src_layout = all_gather.input_layout();
  const Layout tgt_layout = all_gather.output_layout();

  llvm::SmallVector<int64, 4> concat_dims;
  for (int64 i = 0; i < src_layout.rank(); ++i)
    if (src_layout.num_shards_for_dim(src_layout.dim(i)) > 1 &&
        Layout::IsUnshardedDimension(tgt_layout.sharding_spec(i)))
      concat_dims.push_back(i);

  mlir::OpBuilder builder(all_gather);
  builder.setInsertionPointAfter(all_gather);

  if (concat_dims.empty()) {
    mlir::TF::IdentityOp identity = builder.create<mlir::TF::IdentityOp>(
        all_gather.getLoc(), all_gather.input().getType(), all_gather.input());
    SetSingleLayoutOnOp(identity, tgt_layout);

    all_gather.output().replaceAllUsesWith(identity);
    all_gather.erase();
    return mlir::success();
  }

  const mlir::RankedTensorType input_type =
      all_gather.input().getType().dyn_cast<mlir::RankedTensorType>();
  const mlir::RankedTensorType output_type =
      all_gather.output().getType().dyn_cast<mlir::RankedTensorType>();

  if (!input_type)
    return all_gather.emitOpError() << "input type is not a RankedTensorType";
  if (!output_type)
    return all_gather.emitOpError() << "output type is not a RankedTensorType";

  const std::vector<int64_t> output_shape = output_type.getShape();

  // Construct an output with zeros of the correct size, and add our
  // local slice into it. We then all reduce to compute a final result.
  const mlir::Location loc = DT_LOC(all_gather.getLoc());
  const mlir::Value output_shape_const = Int64Const(builder, loc, output_shape);
  const mlir::Value zero_scalar = CreateZeroScalar(builder, loc, input_type);
  const mlir::Value zeros =
      builder.create<mlir::TF::FillOp>(loc, output_shape_const, zero_scalar);

  // For every possible device ID, generate its strided slice ranges. Store all
  // ranges---num_devices * output_shape_size * (begin, end, stride)---as three
  // flat lists.
  // Consider making this a generalized N-dimensional helper on Layout.
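  // A small hypothetical example of the three flat lists: for a 1-D layout
  // gathered over 2 devices into an output of shape [8], device 0 gets
  // begin = 0, end = 4 and device 1 gets begin = 4, end = 8, with stride 1 for
  // both, so each device writes its local shard into its own half of the
  // zero-filled output before the all-reduce sums the halves together.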
  const int64 num_devices = src_layout.num_devices();
  const int64 output_shape_size = output_shape.size();
  llvm::SmallVector<int64, 4> device_id_to_begin_flat;
  llvm::SmallVector<int64, 4> device_id_to_end_flat;
  llvm::SmallVector<int64, 4> device_id_to_strides_flat;
  for (int64 device_id = 0; device_id < num_devices; ++device_id) {
    for (int64 i = 0; i < output_shape_size; ++i) {
      if (llvm::find(concat_dims, i) == std::end(concat_dims)) {
        // For unsharded dimensions, the slice range is [0, dim_size).
        device_id_to_begin_flat.push_back(0);
        device_id_to_end_flat.push_back(output_shape[i]);
      } else {
        // For sharded dimensions, the slice range is [step * device_id, step *
        // (device_id + 1)), where step = dim_size / num_of_shards.
        StatusOr<DeviceLocation> device_loc_or_status =
            src_layout.device_location(device_id);
        if (!device_loc_or_status.ok())
          return all_gather.emitOpError()
                 << device_loc_or_status.status().error_message();
        const DeviceLocation device_loc = device_loc_or_status.ValueOrDie();
        const int32 mesh_idx = src_layout.mesh()
                                   .idx_for_dim(src_layout.sharding_spec(i))
                                   .ValueOrDie();
        const int64 device_offset = device_loc[mesh_idx];
        const int64 step = output_shape[i] / src_layout.num_shards()[i];
        device_id_to_begin_flat.push_back(step * device_offset);
        device_id_to_end_flat.push_back(step * device_offset + step);
      }
      // We need to change every element in the selected slice, so stride is 1
      // for every dimension.
      device_id_to_strides_flat.push_back(1);
    }
  }

  // Reshape the three flat lists into 2D matrices and select one row out of
  // each matrix based on device ID.
  StatusOr<mlir::Value> device_id_scalar_or_status =
      DeviceId(all_gather.input());
  if (!device_id_scalar_or_status.ok())
    return all_gather.emitOpError()
           << device_id_scalar_or_status.status().error_message();
  const mlir::Value device_id_scalar = device_id_scalar_or_status.ValueOrDie();
  const mlir::Value device_id =
      ops_util::ReshapeScalarToSizeType(builder, device_id_scalar, loc);
  // TODO(b/188076080): Clean up device id.
  const mlir::Value start_device_id = ops_util::GetR1Const(
      {src_layout.mesh().min_global_device_id()}, builder, loc);
  const mlir::Value relative_device_id =
      builder.create<mlir::TF::SubOp>(loc, device_id, start_device_id);
  const mlir::Value begin = SelectElementsBasedOnId(
      builder, loc, relative_device_id, device_id_to_begin_flat, num_devices,
      output_shape_size);
  const mlir::Value end = SelectElementsBasedOnId(
      builder, loc, relative_device_id, device_id_to_end_flat, num_devices,
      output_shape_size);
  const mlir::Value strides = SelectElementsBasedOnId(
      builder, loc, relative_device_id, device_id_to_strides_flat, num_devices,
      output_shape_size);

  // Fill in the local portion by slicing into the correct subrange.
  mlir::Value update_result;
  if (src_layout.mesh().is_tpu_mesh()) {
    if (!tgt_layout.mesh().is_tpu_mesh())
      return all_gather.emitOpError()
             << "source and target layout are not both on tpu";
    update_result = builder.create<mlir::TF::XlaDynamicUpdateSliceOp>(
        loc, zeros.getType(), /*input=*/zeros,
        /*update=*/all_gather.input(), /*indices=*/begin);
  } else {
    update_result = builder.create<mlir::TF::TensorStridedSliceUpdateOp>(
        loc, zeros.getType(),
        /*input=*/zeros, begin, end, strides,
        /*value=*/all_gather.input());
  }

  // All reduce among concatenated dimensions.
  absl::flat_hash_set<std::string> reduced_dims;
  for (int i : concat_dims) reduced_dims.insert(src_layout.sharding_spec(i));

  auto partitions_or_status =
      GetAllReducePartitionsFromReducedDims(src_layout, reduced_dims);
  if (!partitions_or_status.ok())
    return all_gather.emitOpError()
           << partitions_or_status.status().error_message();
  auto partitions = partitions_or_status.ValueOrDie();
  const int32 num_partitions = partitions.size();
  assert(num_partitions <= num_devices);
  if (num_partitions == num_devices) {
    // TODO(unknown): Is this check needed? Since we check that num_shards for
    // each reduced_dims in the src layout is > 1, I think we always need
    // communication.
    // If every device lives in its own partition, we don't need to emit a
    // collective.
    SetSingleLayoutOnOp(update_result.getDefiningOp(), tgt_layout);
    all_gather.output().replaceAllUsesWith(update_result);
    all_gather.erase();
    return mlir::success();
  }

  std::vector<int32> partitions_flat;
  for (auto& p : partitions) {
    if (p.second.size() != partitions.begin()->second.size())
      return all_gather.emitOpError() << "partitions had different sizes -- "
                                         "this is not supported in MLIR.";
    partitions_flat.insert(partitions_flat.end(), p.second.begin(),
                           p.second.end());
  }
  const int32 partition_size = partitions.begin()->second.size();
  const mlir::RankedTensorType shaped_type = mlir::RankedTensorType::get(
      {num_partitions, partition_size},
      mlir::IntegerType::get(builder.getContext(), 32));
  const mlir::DenseIntElementsAttr group_assignment =
      mlir::DenseIntElementsAttr::get(shaped_type, partitions_flat);
  StatusOr<std::string> device_type_or_status =
      DeviceTypeFromMesh(src_layout.mesh());
  if (!device_type_or_status.ok())
    return all_gather.emitOpError()
           << device_type_or_status.status().error_message();
  const std::string device_type = device_type_or_status.ValueOrDie();

  // Support bool types by switching to Any reduce rather than Add. For each
  // position in the tensor, only one task in the reduction group can have a 1.
  // This is sufficient.
  const mlir::TensorType type =
      update_result.getType().dyn_cast<mlir::TensorType>();
  absl::string_view reduce_type = kReduceOpAdd;
  if (type && type.getElementType().isInteger(1)) reduce_type = kReduceOpAny;
  mlir::TF::DTensorAllReduceOp all_reduce =
      builder.create<mlir::TF::DTensorAllReduceOp>(
          loc, update_result.getType(), update_result,
          builder.create<mlir::TF::ConstOp>(loc, group_assignment),
          builder.getStringAttr(std::string(reduce_type)),
          builder.getStringAttr(device_type));
  SetSingleLayoutOnOp(all_reduce, tgt_layout);

  all_gather.output().replaceAllUsesWith(all_reduce.getResult());
  all_gather.erase();
  return mlir::LogicalResult::success();
}

mlir::LogicalResult LowerAllScatterOp(
    mlir::TF::DTensorAllScatterOp all_scatter) {
  const Layout original_layout = all_scatter.input_layout();
  const Layout desired_layout = all_scatter.output_layout();

  mlir::tf_device::ClusterOp cluster =
      all_scatter->getParentOfType<mlir::tf_device::ClusterOp>();
  StatusOr<mlir::Value> mesh_coordinates_status =
      GetMeshCoordinatesFromCluster(cluster);
  if (!mesh_coordinates_status.ok())
    return all_scatter.emitOpError()
           << mesh_coordinates_status.status().error_message();
  mlir::Value mesh_coordinates = mesh_coordinates_status.ValueOrDie();

  // We need to compute the slice offset, which is dynamic based on the id.
  //
  // To compute the offset:
  // For axes where there is no splitting, the offset is simply 0.
  // For axes where there is splitting, say axis a, if the new local size of
  // that axis is k, then the offset for the split is
  // mesh_coordinates[sharding_spec[a]] * k, where sharding_spec[a] is the
  // mesh dimension for a. This computation can be encoded in a small 2D matrix
  // of shape [mesh.rank(), layout.rank()] where the [i, j]'th entry is k if
  // sharding_spec[j] = i and axis j is split, and 0 otherwise.
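  // Hypothetical illustration of that matrix: for a mesh of rank 2 ("x", "y")
  // and a rank-2 tensor scattered only along axis 1 onto "y" with a local size
  // of 4, the matrix is [[0, 0], [0, 4]]; multiplying mesh_coordinates (a 1x2
  // row vector) by it below yields an offset of 0 for axis 0 and
  // mesh_coordinates["y"] * 4 for axis 1.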

  mlir::RankedTensorType output_type =
      all_scatter.output().getType().dyn_cast<mlir::RankedTensorType>();
  if (!output_type)
    return all_scatter.emitOpError() << "input must have static rank";

  llvm::ArrayRef<int64_t> output_shape = output_type.getShape();

  // We use a flat list here. The 2D matrix will be of shape
  // [original_layout.mesh().rank(), original_layout.rank()],
  // so the 2D index [i, j] corresponds to the 1D index
  // [i * original_layout.rank() + j].
  std::vector<int32> matrix(original_layout.mesh().rank() *
                            original_layout.rank());
  for (int i = 0; i < original_layout.rank(); ++i) {
    if (original_layout.sharding_spec(i) != desired_layout.sharding_spec(i)) {
      if (mlir::ShapedType::isDynamic(output_shape[i])) {
        return all_scatter.emitOpError()
               << "EmitAllScatter requires slice on input axis " << i
               << " which is dynamic. This is not supported";
      }

      // We already checked above that original_layout.sharding_spec(i) is
      // unsharded.
      int mesh_dim_index = desired_layout.mesh().GetMeshDimIndexWithName(
          desired_layout.sharding_spec(i));
      matrix[mesh_dim_index * original_layout.rank() + i] = output_shape[i];
    }
  }

  // Produce the constant tensor for the slice shape and the matrix.

  mlir::OpBuilder builder(all_scatter);

  // Slice shape has to be int32_t, as it must match the type of the offset to
  // mlir::TF::SliceOp. The slice offset has to be int32_t as TPU doesn't have
  // int64_t MatMul (which we use to compute the offset).
  llvm::SmallVector<int32_t> output_shape_int32(output_shape.begin(),
                                                output_shape.end());
  mlir::Value slice_shape_value =
      IntConst(builder, all_scatter.getLoc(), output_shape_int32);

  mlir::RankedTensorType matrix_type = mlir::RankedTensorType::get(
      {original_layout.mesh().rank(), original_layout.rank()},
      builder.getIntegerType(32));
  mlir::Attribute matrix_attr =
      mlir::DenseIntElementsAttr::get(matrix_type, matrix);
  mlir::Value matrix_value =
      builder.create<mlir::TF::ConstOp>(all_scatter.getLoc(), matrix_attr)
          .getResult();

  // Compute the offset from matrix_value and mesh_coordinates.
  mlir::TF::MatMulOp offset = builder.create<mlir::TF::MatMulOp>(
      all_scatter.getLoc(),
      mlir::RankedTensorType::get({1, original_layout.rank()},
                                  builder.getIntegerType(32)),
      mesh_coordinates, matrix_value);

  // The input to Slice needs to be rank 1, so we need to squeeze it.
  mlir::TF::SqueezeOp offset_squeezed = builder.create<mlir::TF::SqueezeOp>(
      all_scatter.getLoc(),
      mlir::RankedTensorType::get({original_layout.rank()},
                                  builder.getIntegerType(32)),
      offset.product(), builder.getI64ArrayAttr({0}));

  auto result = builder.create<mlir::TF::SliceOp>(
      all_scatter.getLoc(), output_type, all_scatter.input(),
      offset_squeezed.output(), slice_shape_value);

  SetSingleLayoutOnOp(result, desired_layout);

  all_scatter.output().replaceAllUsesExcept(result.output(), result);
  all_scatter.erase();

  return mlir::LogicalResult::success();
}

struct DTensorAllReduceLowering
    : public DTensorAllReduceLoweringBase<DTensorAllReduceLowering> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::ModuleOp module = getOperation();

    // Find all DTensorAllReduce ops.
    llvm::SmallVector<mlir::TF::DTensorAllReduceOp, 4> all_reduces;
    module.walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
      all_reduces.emplace_back(all_reduce);
    });

    // Replace every DTensorAllReduce op with device-specific implementations.
    for (auto& all_reduce : all_reduces)
      if (mlir::failed(LowerAllReduceOp(context, all_reduce)))
        return signalPassFailure();
  }
};

struct DTensorReduceScatterLowering
    : public DTensorReduceScatterLoweringBase<DTensorReduceScatterLowering> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // Find all DTensorReduceScatter ops.
    llvm::SmallVector<mlir::TF::DTensorReduceScatterOp, 4> reduce_scatters;
    module.walk([&](mlir::TF::DTensorReduceScatterOp reduce_scatter) {
      reduce_scatters.emplace_back(reduce_scatter);
    });

    // Replace every DTensorReduceScatter op with device-specific
    // implementations.
    for (auto& reduce_scatter : reduce_scatters)
      if (mlir::failed(LowerReduceScatterOp(reduce_scatter)))
        return signalPassFailure();
  }
};

struct DTensorAllGatherLowering
    : public DTensorAllGatherLoweringBase<DTensorAllGatherLowering> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // Process all DTensorAllGather ops.
    llvm::SmallVector<mlir::TF::DTensorAllGatherOp, 4> all_gathers;
    module.walk([&](mlir::TF::DTensorAllGatherOp all_gather) {
      all_gathers.emplace_back(all_gather);
    });

    for (mlir::TF::DTensorAllGatherOp all_gather : all_gathers)
      if (mlir::failed(LowerAllGatherOp(all_gather)))
        return signalPassFailure();
  }
};

struct DTensorAllScatterLowering
    : public DTensorAllScatterLoweringBase<DTensorAllScatterLowering> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // Process all DTensorAllScatter ops.
    llvm::SmallVector<mlir::TF::DTensorAllScatterOp, 4> all_scatters;
    module.walk([&](mlir::TF::DTensorAllScatterOp all_scatter) {
      all_scatters.emplace_back(all_scatter);
    });

    for (mlir::TF::DTensorAllScatterOp all_scatter : all_scatters)
      if (mlir::failed(LowerAllScatterOp(all_scatter)))
        return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorAllReduceLoweringPass() {
  return std::make_unique<DTensorAllReduceLowering>();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorReduceScatterLoweringPass() {
  return std::make_unique<DTensorReduceScatterLowering>();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorAllGatherLoweringPass() {
  return std::make_unique<DTensorAllGatherLowering>();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorAllScatterLoweringPass() {
  return std::make_unique<DTensorAllScatterLowering>();
}

}  // namespace dtensor
}  // namespace tensorflow
829