/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <atomic>
#include <string>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/collectives_common.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/group_assignment.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

namespace ops_util = ::mlir::TF::collection_ops_util;
constexpr int32 kUninitializedGroupKey = 0;

// A counter that is used to generate shift base values for TF collective group
// and instance keys. Every TF collective AllReduce op in a program gets a value
// from this counter. The value increments according to the position of the
// AllReduce op in the program. Different hosts go through exactly the same MLIR
// logic and therefore iterate over AllReduce ops in the same order (even in the
// presence of control flow), so they should independently generate the same
// counter value for matching AllReduce ops across hosts.
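// For example (illustrative numbers only): if a program contains three
// AllReduce ops A, B, and C in that order, every host lowering this program
// assigns them key bases k, k+1, and k+2 for the same starting value k, so
// matching ops on different hosts agree on their group/instance key bases.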
static std::atomic<int32> tf_collective_key_base{0};

}  // namespace
}  // namespace dtensor
}  // namespace tensorflow

#ifdef PLATFORM_GOOGLE
// Use the Google internal version of EmitAllReduceForXla.
#include "collective_lowering_google.inc"
#else
namespace tensorflow {
namespace dtensor {
namespace {
constexpr char kCrossReplica[] = "CrossReplica";

mlir::LogicalResult EmitAllReduceForXla(
    mlir::MLIRContext& context, mlir::OpBuilder& builder,
    mlir::TF::DTensorAllReduceOp all_reduce,
    mlir::DenseIntElementsAttr group_assignment_attr, int32 key_base,
    mlir::Operation** final_op) {
  // For TPUs, lower to XlaAllReduce straightforwardly.
  *final_op = builder.create<mlir::TF::XlaAllReduceOp>(
      all_reduce.getLoc(), all_reduce.getResult().getType(), all_reduce.input(),
      all_reduce.group_assignment(), all_reduce.reduce_opAttr(),
      builder.getStringAttr(kCrossReplica));
  return mlir::success();
}
}  // namespace
}  // namespace dtensor
}  // namespace tensorflow
#endif

namespace tensorflow {
namespace dtensor {
namespace {
// Emit a host CollectiveReduce op for the given input.
// `group_assignment` is used to generate an array of group keys.
// `device_id` slices into that array to get the key for a device at runtime.
// `key_base` is the common part shared by all group keys.
// `device_id` is an mlir::Value that will contain the device ID at runtime.
// `host_group_size` sets host collective group size. It should match the number
//   of active devices running the host collective and supplying device IDs,
//   else the host collective will crash or hang.
mlir::Operation* EmitCollectiveReduce(
    mlir::OpBuilder& builder, const mlir::Location& loc, mlir::Value input,
    const std::string& reduce_op_str,
    const mlir::DenseIntElementsAttr& group_assignment, int32 key_base,
    mlir::Value device_id, int32 host_group_size,
    const mlir::StringRef device_type) {
  DCHECK_EQ(group_assignment.getType().getRank(), 2);
  auto shape = group_assignment.getType().getShape();
  const int32 num_groups = shape[0];
  const int32 group_size = shape[1];
  const int32 num_devices = num_groups * group_size;
  const mlir::TensorType input_type =
      input.getType().dyn_cast<mlir::TensorType>();

  const bool need_int32_to_int64_upcast =
      (device_type.endswith("GPU") && input_type &&
       input_type.getElementType().isInteger(32));

  if (need_int32_to_int64_upcast) {
    LOG(WARNING) << "On GPU, collective reduce of int32 is not supported. "
                    "Casting to int64 as a workaround: "
                 << mlir::debugString(loc);

    mlir::TF::CastOp cast_to_int64 = builder.create<mlir::TF::CastOp>(
        loc,
        mlir::RankedTensorType::get(input_type.getShape(),
                                    builder.getIntegerType(64)),
        input);
    input = cast_to_int64.getResult();
  }
  mlir::Value group_key_scalar;
  llvm::SmallVector<int32, 4> device_id_to_group_key(num_devices);
  device_id_to_group_key.resize(num_devices, kUninitializedGroupKey);
  // 21 bits + 11 bits allow roughly 2M all-reduces in one program and up to a
  // full DF pod.
  DCHECK_LT(key_base, 1L << 21) << "Reaching 2^21 all-reduces.";
  DCHECK_LE(num_devices, 1L << 11) << "Exceeding 2048 groups.";
  for (const auto& it :
       llvm::enumerate(group_assignment.getValues<llvm::APInt>())) {
    int32 device_id = it.value().getSExtValue();
    DCHECK_LE(0, device_id);
    DCHECK_LT(device_id, num_devices);
    DCHECK_EQ(device_id_to_group_key[device_id], kUninitializedGroupKey);
    const int32 group_id = static_cast<int32>(it.index()) / group_size;
    device_id_to_group_key[device_id] = (key_base << 11) ^ group_id;
  }
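  // For example (illustrative values only): with key_base = 2 and
  // group_assignment = [[0, 1], [2, 3]], devices 0 and 1 belong to group 0 and
  // devices 2 and 3 to group 1, so device_id_to_group_key becomes
  // {(2 << 11) ^ 0, (2 << 11) ^ 0, (2 << 11) ^ 1, (2 << 11) ^ 1}
  //   = {4096, 4096, 4097, 4097}.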

  // Create a scalar group key by slicing device_id_to_group_key with
  // device_id.
  auto group_key_slice = builder.create<mlir::TF::SliceOp>(
      loc, EffectivelyScalarR1Type(builder.getIntegerType(32)),
      /*input=*/IntConst(builder, loc, device_id_to_group_key),
      /*begin=*/device_id,
      /*size=*/IntConst(builder, loc, {1}));
  auto group_key_reshape = builder.create<mlir::TF::ReshapeOp>(
      loc, /*tensor=*/group_key_slice.getResult(),
      /*shape=*/ops_util::GetR1Const({}, builder, loc));
  group_key_scalar = group_key_reshape.getResult();

  // Generate a unique instance key for this collective.
  mlir::Value instance_key_scalar =
      ops_util::CreateScalarConst(static_cast<int32>(key_base), builder, loc);

  const bool is_mean_op = reduce_op_str == kReduceOpMean;
  mlir::Value group_size_scalar =
      ops_util::CreateScalarConst(host_group_size, builder, loc);
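  // A Mean reduction is expressed as an elementwise "Add" merge across the
  // group followed by a "Div" by the group size as the final op; every other
  // reduction passes its merge result through unchanged ("Id").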
  auto collective_reduce = builder.create<mlir::TF::CollectiveReduceV2Op>(
      loc, /*output_type=*/input.getType(), input, group_size_scalar,
      group_key_scalar, instance_key_scalar,
      /*ordering_token=*/mlir::ValueRange({}),
      /*merge_op=*/builder.getStringAttr(is_mean_op ? "Add" : reduce_op_str),
      /*final_op=*/builder.getStringAttr(is_mean_op ? "Div" : "Id"),
      /*communication_hint=*/builder.getStringAttr(""),
      /*timeout_seconds=*/builder.getF32FloatAttr(0.),
      /*max_subdivs_per_device=*/builder.getI64IntegerAttr(16));
  SetSingleLayoutOnOp(collective_reduce, Layout::Empty());
  if (need_int32_to_int64_upcast) {
    return builder.create<mlir::TF::CastOp>(
        loc,
        mlir::RankedTensorType::get(input_type.getShape(),
                                    builder.getIntegerType(32)),
        collective_reduce);
  }
  return collective_reduce;
}

mlir::LogicalResult LowerAllReduceOpImpl(
    mlir::MLIRContext& context, mlir::OpBuilder& builder,
    mlir::TF::DTensorAllReduceOp all_reduce, mlir::Value* value) {
  mlir::Location loc = all_reduce.getLoc();
  StatusOr<Layout> output_layout =
      ExtractRequiredSingleLayoutFromOp(all_reduce);
  if (!output_layout.ok()) {
    return all_reduce.emitOpError(output_layout.status().error_message());
  }
  mlir::DenseIntElementsAttr group_assignment_attr;
  if (!matchPattern(all_reduce.group_assignment(),
                    m_Constant(&group_assignment_attr)))
    return mlir::emitError(loc, "group_assignment must be a constant.");
  if (group_assignment_attr.getType().getRank() != 2)
    return mlir::emitError(loc, "group_assignment should have two dimensions.");
  int32 group_size = group_assignment_attr.getType().getShape()[1];

  // This will become more general when Topology is properly defined.
  const bool is_tpu = all_reduce.device_type().endswith("TPU");
  // Use an atomic counter to generate bases for group and instance keys.
  int32 key_base = tf_collective_key_base++;

  mlir::Operation* final_op;
  if (is_tpu) {
    if (mlir::failed(EmitAllReduceForXla(context, builder, all_reduce,
                                         group_assignment_attr, key_base,
                                         &final_op))) {
      return mlir::failure();
    }
  } else {
    // Generate CPU/GPU collective. CPU/GPU collectives identify groups on
    // the basis of a local group key. We must generate an appropriate group
    // key based on our device ID. This is expressible as an algebraic
    // function of the device id, but we instead encode the
    // device_id->group_key as an explicit map value and look up the result
    // at runtime. Note that the order we map devices to partitions is not
    // deterministic, and moreover if we have multiple distinct reduction
    // groups in one program reducing over all hosts and reducing over pairs
    // of hosts, we need unique ids for each case.
    mlir::Value device_id = ops_util::ReshapeScalarToSizeType(
        builder, DeviceId(all_reduce.getResult()).ValueOrDie(), loc);
    // TODO(b/188076080): Clean up device id.
    mlir::Value start_device_id = ops_util::GetR1Const(
        {(*output_layout).mesh().min_global_device_id()}, builder, loc);
    mlir::Value relative_device_id =
        builder.create<mlir::TF::SubOp>(loc, device_id, start_device_id);

    final_op = EmitCollectiveReduce(
        builder, loc, all_reduce.input(), all_reduce.reduce_op().str(),
        group_assignment_attr, key_base, relative_device_id,
        /*host_group_size=*/group_size, all_reduce.device_type().str());
  }
  SetSingleLayoutOnOp(final_op, *output_layout);
  *value = final_op->getResult(0);
  return mlir::success();
}

template <class ReduceOpType>
mlir::LogicalResult ConvertBoolReduce(ReduceOpType reduce_op) {
  mlir::OpBuilder builder(reduce_op);
  const mlir::Location loc = reduce_op.getLoc();
  const mlir::Type output_type = reduce_op.getResult().getType();
  const mlir::Type input_type = reduce_op.getOperand(0).getType();

  // Handle bools by first casting to int32 and swapping All/Any for Min/Max.
  const mlir::TensorType& tensor_input_type =
      input_type.dyn_cast<mlir::TensorType>();
  const mlir::TensorType& tensor_output_type =
      output_type.dyn_cast<mlir::TensorType>();
  if (tensor_input_type && tensor_output_type &&
      tensor_input_type.getElementType().isInteger(1)) {
    if (reduce_op.reduce_opAttr().getValue().str() == kReduceOpAll)
      reduce_op.reduce_opAttr(builder.getStringAttr(std::string(kReduceOpMin)));
    else if (reduce_op.reduce_opAttr().getValue().str() == kReduceOpAny)
      reduce_op.reduce_opAttr(builder.getStringAttr(std::string(kReduceOpMax)));
    else
      return reduce_op.emitOpError()
             << "reduce for boolean only supports 'All' or 'Any' reduction. "
             << "Received '" << reduce_op.reduce_opAttr().getValue().str()
             << "'";
    const mlir::Type integer_input_type = mlir::RankedTensorType::get(
        tensor_input_type.getShape(), builder.getIntegerType(32));
    mlir::TF::CastOp cast_to_int32 = builder.create<mlir::TF::CastOp>(
        loc, integer_input_type, reduce_op.input());
    reduce_op.setOperand(0, cast_to_int32.y());
    const mlir::Type integer_output_type = mlir::RankedTensorType::get(
        tensor_output_type.getShape(), builder.getIntegerType(32));
    reduce_op.output().setType(integer_output_type);

    // Add cast back to boolean after reduction.
    mlir::Value result = reduce_op.output();
    builder.setInsertionPointAfter(reduce_op);
    mlir::TF::CastOp cast_to_bool =
        builder.create<mlir::TF::CastOp>(loc, output_type, result);
    StatusOr<Layout> result_layout =
        ExtractRequiredSingleLayoutFromOp(result.getDefiningOp());
    if (!result_layout.ok()) {
      return reduce_op.emitOpError(result_layout.status().error_message());
    }
    SetSingleLayoutOnOp(cast_to_bool, *result_layout);
    reduce_op.output().replaceAllUsesExcept(cast_to_bool.y(), cast_to_bool);
  }

  return mlir::success();
}

mlir::LogicalResult LowerAllReduceOp(mlir::MLIRContext& context,
                                     mlir::TF::DTensorAllReduceOp all_reduce) {
  if (mlir::failed(ConvertBoolReduce<mlir::TF::DTensorAllReduceOp>(all_reduce)))
    return mlir::failure();

  mlir::OpBuilder builder(all_reduce);
  mlir::Value result;
  if (mlir::failed(LowerAllReduceOpImpl(context, builder, all_reduce, &result)))
    return mlir::failure();

  all_reduce.replaceAllUsesWith(result);
  all_reduce.erase();
  return mlir::success();
}

mlir::LogicalResult LowerReduceScatterOp(
    mlir::TF::DTensorReduceScatterOp reduce_scatter) {
  mlir::Location loc = reduce_scatter.getLoc();

  StatusOr<Layout> output_layout =
      ExtractRequiredSingleLayoutFromOp(reduce_scatter);
  if (!output_layout.ok()) {
    return reduce_scatter.emitOpError(output_layout.status().error_message());
  }
  mlir::DenseIntElementsAttr group_assignment_attr;
  if (!matchPattern(reduce_scatter.group_assignment(),
                    m_Constant(&group_assignment_attr)))
    return reduce_scatter.emitOpError("group_assignment must be a constant.");
  if (group_assignment_attr.getType().getRank() != 2)
    return reduce_scatter.emitOpError(
        "group_assignment should have two dimensions.");

  mlir::OpBuilder builder(reduce_scatter);
  if (reduce_scatter.device_type().endswith("TPU")) {
    if (mlir::failed(ConvertBoolReduce<mlir::TF::DTensorReduceScatterOp>(
            reduce_scatter)))
      return mlir::failure();
    // For TPUs, lower to XlaReduceScatter straightforwardly.
    mlir::Operation* xla_reduce_scatter =
        builder.create<mlir::TF::XlaReduceScatterOp>(
            loc, reduce_scatter.getResult().getType(), reduce_scatter.input(),
            reduce_scatter.group_assignment(),
            reduce_scatter.scatter_dimension(), reduce_scatter.reduce_opAttr());
    SetSingleLayoutOnOp(xla_reduce_scatter, *output_layout);
    reduce_scatter.replaceAllUsesWith(xla_reduce_scatter);
  } else {
    // For non-TPU devices, decompose into DTensorAllReduce + DTensorAllScatter.
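    // For example (illustrative shapes only): an 8x8 tensor reduce-scattered
    // along axis 0 across 2 devices becomes a DTensorAllReduce over the full
    // 8x8 tensor followed by a DTensorAllScatter that leaves each device with
    // its own 4x8 shard.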
    StatusOr<Layout> input_layout =
        ExtractRequiredLayoutFromOperand(reduce_scatter.input());
    if (!input_layout.ok()) {
      // If input layout is not defined, modify the output_layout based on the
      // scattered dimension.
      mlir::DenseIntElementsAttr scatter_attr;
      if (!matchPattern(reduce_scatter.scatter_dimension(),
                        m_Constant(&scatter_attr))) {
        return reduce_scatter.emitOpError(
            "Scatter dimension not constant integer array.");
      }
      mlir::APInt scatter_dim = *scatter_attr.begin();
      std::vector<string> input_sharding_spec =
          output_layout->sharding_spec_strs();
      input_sharding_spec[scatter_dim.getSExtValue()] = Layout::kUnshardedDim;
      input_layout =
          Layout::GetLayout(input_sharding_spec, output_layout->mesh());
    }

    if (!input_layout.ok()) {
      return reduce_scatter.emitOpError(input_layout.status().error_message());
    }

    auto dtensor_allreduce = builder.create<mlir::TF::DTensorAllReduceOp>(
        reduce_scatter.getLoc(), reduce_scatter.getOperand(0).getType(),
        reduce_scatter.getOperand(0), reduce_scatter.group_assignment(),
        reduce_scatter.reduce_op(), reduce_scatter.device_type());
    SetSingleLayoutOnOp(dtensor_allreduce, *input_layout);

    mlir::Operation* dtensor_all_scatter =
        builder.create<mlir::TF::DTensorAllScatterOp>(
            reduce_scatter.getLoc(), reduce_scatter.getResult().getType(),
            dtensor_allreduce.getResult(),
            mlir::dtensor::LayoutAttr::get(builder.getContext(), *input_layout),
            mlir::dtensor::LayoutAttr::get(builder.getContext(),
                                           *output_layout));
    SetSingleLayoutOnOp(dtensor_all_scatter, *output_layout);
    reduce_scatter.replaceAllUsesWith(dtensor_all_scatter);
  }
  reduce_scatter.erase();
  return mlir::success();
}

mlir::Value CreateZeroScalar(mlir::OpBuilder& builder, mlir::Location loc,
                             mlir::RankedTensorType type) {
  const mlir::Value zero_scalar = ops_util::CreateScalarConst(0, builder, loc);
  return builder.create<mlir::TF::CastOp>(
      loc, mlir::RankedTensorType::get({}, type.getElementType()), zero_scalar);
}

// device_id is the relative device_id in a mesh (device id - mesh's 1st device
// id).
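// For example (illustrative values only): with num_devices = 2,
// output_shape_size = 2, and candidates_flat = {0, 0, 2, 0}, the candidates
// matrix is [[0, 0], [2, 0]]; a relative device_id of 1 selects the row
// [2, 0].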
mlir::Value SelectElementsBasedOnId(
    mlir::OpBuilder& builder, mlir::Location loc, mlir::Value device_id,
    const llvm::SmallVectorImpl<int64>& candidates_flat, int64 num_devices,
    int64 output_shape_size) {
  // Reshape the flat list to a matrix of shape num_devices * output_shape_size.
  const mlir::Value candidates_flat_const =
      ops_util::GetR1Const(candidates_flat, builder, loc);
  const mlir::Value candidates_shape =
      ops_util::GetR1Const({num_devices, output_shape_size}, builder, loc);
  const mlir::Value candidates = builder.create<mlir::TF::ReshapeOp>(
      loc, candidates_flat_const, candidates_shape);

  // Add a zero after the only value in the 1x1 device_id tensor.
  const mlir::Value device_id_paddings = builder.create<mlir::TF::ReshapeOp>(
      loc, ops_util::GetR1Const({0, 1}, builder, loc),
      ops_util::GetR1Const({1, 2}, builder, loc));
  const mlir::Value device_id_padded = builder.create<mlir::TF::PadOp>(
      loc, candidates_shape.getType(), /*input=*/device_id,
      /*paddings=*/device_id_paddings);
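  // The padded value is [device_id, 0], which is used below as the `begin`
  // index of a Slice: it starts at row `device_id`, column 0 of the candidates
  // matrix.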

  // Slice a vertical vector out of the 2D candidates matrix.
  const mlir::RankedTensorType chosen_shape_type = mlir::RankedTensorType::get(
      {1, output_shape_size}, builder.getIntegerType(32));
  const mlir::Value chosen_shape_const =
      ops_util::GetR1Const(chosen_shape_type.getShape(), builder, loc);
  const mlir::Value chosen = builder.create<mlir::TF::SliceOp>(
      loc, chosen_shape_type, /*input=*/candidates, /*begin=*/device_id_padded,
      /*size=*/chosen_shape_const);

  // Remove the leading dimension of size 1 before returning the result.
  return builder.create<mlir::TF::ReshapeOp>(
      loc, chosen, ops_util::GetR1Const({output_shape_size}, builder, loc));
}

mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) {
  const Layout src_layout = all_gather.input_layout();
  const Layout tgt_layout = all_gather.output_layout();

  llvm::SmallVector<int64, 4> concat_dims;
  for (int64 i = 0; i < src_layout.rank(); ++i)
    if (src_layout.num_shards_for_dim(src_layout.dim(i)) > 1 &&
        Layout::IsUnshardedDimension(tgt_layout.sharding_spec(i)))
      concat_dims.push_back(i);

  mlir::OpBuilder builder(all_gather);
  builder.setInsertionPointAfter(all_gather);

  if (concat_dims.empty()) {
    mlir::TF::IdentityOp identity = builder.create<mlir::TF::IdentityOp>(
        all_gather.getLoc(), all_gather.input().getType(), all_gather.input());
    SetSingleLayoutOnOp(identity, tgt_layout);

    all_gather.output().replaceAllUsesWith(identity);
    all_gather.erase();
    return mlir::success();
  }

  const mlir::RankedTensorType input_type =
      all_gather.input().getType().dyn_cast<mlir::RankedTensorType>();
  const mlir::RankedTensorType output_type =
      all_gather.output().getType().dyn_cast<mlir::RankedTensorType>();

  if (!input_type)
    return all_gather.emitOpError() << "input type is not a RankedTensorType";
  if (!output_type)
    return all_gather.emitOpError() << "output type is not a RankedTensorType";

  const std::vector<int64_t> output_shape = output_type.getShape();

  // Construct an output with zeros of the correct size, and add our
  // local slice into it. We then all reduce to compute a final result.
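  // For example (illustrative values only): gathering a length-4 vector that
  // is sharded across 2 devices, device 0 contributes [1, 2, 0, 0] and
  // device 1 contributes [0, 0, 3, 4]; an Add all-reduce then yields
  // [1, 2, 3, 4] on every device.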
  const mlir::Location loc = DT_LOC(all_gather.getLoc());
  const mlir::Value output_shape_const = Int64Const(builder, loc, output_shape);
  const mlir::Value zero_scalar = CreateZeroScalar(builder, loc, input_type);
  const mlir::Value zeros =
      builder.create<mlir::TF::FillOp>(loc, output_shape_const, zero_scalar);

  // For every possible device ID, generate its strided slice ranges. Store all
  // ranges---num_devices * output_shape_size * (begin, end, stride)---as three
  // flat lists.
  // Consider making this a generalized N-dimensional helper on Layout.
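  // For example (illustrative values only): for an output shape of [4, 8]
  // gathered along dimension 0 from 2 devices, device 0 gets begin [0, 0],
  // end [2, 8], strides [1, 1] and device 1 gets begin [2, 0], end [4, 8],
  // strides [1, 1].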
  const int64 num_devices = src_layout.num_devices();
  const int64 output_shape_size = output_shape.size();
  llvm::SmallVector<int64, 4> device_id_to_begin_flat;
  llvm::SmallVector<int64, 4> device_id_to_end_flat;
  llvm::SmallVector<int64, 4> device_id_to_strides_flat;
  for (int64 device_id = 0; device_id < num_devices; ++device_id) {
    for (int64 i = 0; i < output_shape_size; ++i) {
      if (llvm::find(concat_dims, i) == std::end(concat_dims)) {
        // For unsharded dimensions, the slice range is [0, dim_size).
        device_id_to_begin_flat.push_back(0);
        device_id_to_end_flat.push_back(output_shape[i]);
      } else {
        // For sharded dimensions, the slice range is [step * device_id, step *
        // (device_id + 1)), where step = dim_size / num_of_shards.
        StatusOr<DeviceLocation> device_loc_or_status =
            src_layout.device_location(device_id);
        if (!device_loc_or_status.ok())
          return all_gather.emitOpError()
                 << device_loc_or_status.status().error_message();
        const DeviceLocation device_loc = device_loc_or_status.ValueOrDie();
        const int32 mesh_idx = src_layout.mesh()
                                   .idx_for_dim(src_layout.sharding_spec(i))
                                   .ValueOrDie();
        const int64 device_offset = device_loc[mesh_idx];
        const int64 step = output_shape[i] / src_layout.num_shards()[i];
        device_id_to_begin_flat.push_back(step * device_offset);
        device_id_to_end_flat.push_back(step * device_offset + step);
      }
      // We need to change every element in the selected slice, so stride is 1
      // for every dimension.
      device_id_to_strides_flat.push_back(1);
    }
  }

  // Resize three flat lists to 2D matrices and select one vertical vector out
  // of every matrix based on device ID.
  StatusOr<mlir::Value> device_id_scalar_or_status =
      DeviceId(all_gather.input());
  if (!device_id_scalar_or_status.ok())
    return all_gather.emitOpError()
           << device_id_scalar_or_status.status().error_message();
  const mlir::Value device_id_scalar = device_id_scalar_or_status.ValueOrDie();
  const mlir::Value device_id =
      ops_util::ReshapeScalarToSizeType(builder, device_id_scalar, loc);
  // TODO(b/188076080): Clean up device id.
  const mlir::Value start_device_id = ops_util::GetR1Const(
      {src_layout.mesh().min_global_device_id()}, builder, loc);
  const mlir::Value relative_device_id =
      builder.create<mlir::TF::SubOp>(loc, device_id, start_device_id);
  const mlir::Value begin = SelectElementsBasedOnId(
      builder, loc, relative_device_id, device_id_to_begin_flat, num_devices,
      output_shape_size);
  const mlir::Value end = SelectElementsBasedOnId(
      builder, loc, relative_device_id, device_id_to_end_flat, num_devices,
      output_shape_size);
  const mlir::Value strides = SelectElementsBasedOnId(
      builder, loc, relative_device_id, device_id_to_strides_flat, num_devices,
      output_shape_size);

  // Fill in the local portion by slicing into the correct subrange.
  mlir::Value update_result;
  if (src_layout.mesh().is_tpu_mesh()) {
    if (!tgt_layout.mesh().is_tpu_mesh())
      return all_gather.emitOpError()
             << "source and target layout are not both on tpu";
    update_result = builder.create<mlir::TF::XlaDynamicUpdateSliceOp>(
        loc, zeros.getType(), /*input=*/zeros,
        /*update=*/all_gather.input(), /*indices=*/begin);
  } else {
    update_result = builder.create<mlir::TF::TensorStridedSliceUpdateOp>(
        loc, zeros.getType(),
        /*input=*/zeros, begin, end, strides,
        /*value=*/all_gather.input());
  }

  // All reduce among concatenated dimensions.
  absl::flat_hash_set<std::string> reduced_dims;
  for (int i : concat_dims) reduced_dims.insert(src_layout.sharding_spec(i));

  auto partitions_or_status =
      GetAllReducePartitionsFromReducedDims(src_layout, reduced_dims);
  if (!partitions_or_status.ok())
    return all_gather.emitOpError()
           << partitions_or_status.status().error_message();
  auto partitions = partitions_or_status.ValueOrDie();
  const int32 num_partitions = partitions.size();
  assert(num_partitions <= num_devices);
  if (num_partitions == num_devices) {
    // TODO(unknown): Is this check needed? Since we check that num_shards for
    // each reduced_dims in the src layout is > 1, I think we always need
    // communication.
    // If every device lives in its own partition, we don't need to emit a
    // collective.
    SetSingleLayoutOnOp(update_result.getDefiningOp(), tgt_layout);
    all_gather.output().replaceAllUsesWith(update_result);
    all_gather.erase();
    return mlir::success();
  }

  std::vector<int32> partitions_flat;
  for (auto& p : partitions) {
    if (p.second.size() != partitions.begin()->second.size())
      return all_gather.emitOpError() << "partitions had different sizes -- "
                                         "this is not supported in MLIR.";
    partitions_flat.insert(partitions_flat.end(), p.second.begin(),
                           p.second.end());
  }
  const int32 partition_size = partitions.begin()->second.size();
  const mlir::RankedTensorType shaped_type = mlir::RankedTensorType::get(
      {num_partitions, partition_size},
      mlir::IntegerType::get(builder.getContext(), 32));
  const mlir::DenseIntElementsAttr group_assignment =
      mlir::DenseIntElementsAttr::get(shaped_type, partitions_flat);
  StatusOr<std::string> device_type_or_status =
      DeviceTypeFromMesh(src_layout.mesh());
  if (!device_type_or_status.ok())
    return all_gather.emitOpError()
           << device_type_or_status.status().error_message();
  const std::string device_type = device_type_or_status.ValueOrDie();

  // Support bool types by switching to Any reduce rather than Add. For each
  // position in the tensor, only one task in the reduction group can have a 1.
  // This is sufficient.
  const mlir::TensorType type =
      update_result.getType().dyn_cast<mlir::TensorType>();
  absl::string_view reduce_type = kReduceOpAdd;
  if (type && type.getElementType().isInteger(1)) reduce_type = kReduceOpAny;
  mlir::TF::DTensorAllReduceOp all_reduce =
      builder.create<mlir::TF::DTensorAllReduceOp>(
          loc, update_result.getType(), update_result,
          builder.create<mlir::TF::ConstOp>(loc, group_assignment),
          builder.getStringAttr(std::string(reduce_type)),
          builder.getStringAttr(device_type));
  SetSingleLayoutOnOp(all_reduce, tgt_layout);

  all_gather.output().replaceAllUsesWith(all_reduce.getResult());
  all_gather.erase();
  return mlir::LogicalResult::success();
}

mlir::LogicalResult LowerAllScatterOp(
    mlir::TF::DTensorAllScatterOp all_scatter) {
  const Layout original_layout = all_scatter.input_layout();
  const Layout desired_layout = all_scatter.output_layout();

  mlir::tf_device::ClusterOp cluster =
      all_scatter->getParentOfType<mlir::tf_device::ClusterOp>();
  StatusOr<mlir::Value> mesh_coordinates_status =
      GetMeshCoordinatesFromCluster(cluster);
  if (!mesh_coordinates_status.ok())
    return all_scatter.emitOpError()
           << mesh_coordinates_status.status().error_message();
  mlir::Value mesh_coordinates = mesh_coordinates_status.ValueOrDie();

  // We need to compute the slice offset, which is dynamic based on the id.
  //
  // To compute the offset:
  // For axes where there is no splitting, the offset is simply 0.
  // For axes where there is splitting, say axis a, if the new local size of
  // that axis is k, then the offset for the split is
  // mesh_coordinates[sharding_spec[a]]*k, where sharding_spec[a] is the mesh
  // dimension for axis a. This computation can be encoded in a small 2D matrix
  // of shape [mesh.rank(), layout.rank()] where the [i, j]'th entry is k if
  // sharding_spec[j]=i and this is a dimension with split and 0 otherwise.
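  // For example (illustrative values only): for a mesh of shape
  // {'x': 2, 'y': 2} (mesh rank 2) and a rank-2 tensor whose axis 0 becomes
  // sharded over 'y' with a new local shape of [2, 8], the matrix is
  // [[0, 0], [2, 0]] (one row per mesh dimension, one column per tensor axis).
  // The device with mesh coordinates [x=0, y=1] then computes the offset
  // [0, 1] x [[0, 0], [2, 0]] = [2, 0], i.e. its slice starts at row 2.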

  mlir::RankedTensorType output_type =
      all_scatter.output().getType().dyn_cast<mlir::RankedTensorType>();
  if (!output_type)
    return all_scatter.emitOpError() << "input must have static rank";

  llvm::ArrayRef<int64_t> output_shape = output_type.getShape();

  // We use a flat list here. The 2D matrix will be of shape
  // [original_layout.mesh().rank(), original_layout.rank()]
  // so the 2D index [i, j] corresponds to the 1D index of
  // [i * original_layout.rank() + j].
  std::vector<int32> matrix(original_layout.mesh().rank() *
                            original_layout.rank());
  for (int i = 0; i < original_layout.rank(); ++i) {
    if (original_layout.sharding_spec(i) != desired_layout.sharding_spec(i)) {
      if (mlir::ShapedType::isDynamic(output_shape[i])) {
        return all_scatter.emitOpError()
               << "EmitAllScatter requires slice on input axis " << i
               << " which is dynamic. This is not supported";
      }

      // We already checked above that original_layout.sharding_spec(i) is
      // unsharded.
      int mesh_dim_index = desired_layout.mesh().GetMeshDimIndexWithName(
          desired_layout.sharding_spec(i));
      matrix[mesh_dim_index * original_layout.rank() + i] = output_shape[i];
    }
  }

  // Produce the constant tensor for the slice shape and the matrix.

  mlir::OpBuilder builder(all_scatter);

  // Slice shape has to be int32_t, as it must match the type of the offset to
  // mlir::TF::SliceOp. The slice offset has to be int32_t as TPU doesn't have
  // int64_t MatMul (which we use to compute the offset).
  llvm::SmallVector<int32_t> output_shape_int32(output_shape.begin(),
                                                output_shape.end());
  mlir::Value slice_shape_value =
      IntConst(builder, all_scatter.getLoc(), output_shape_int32);

  mlir::RankedTensorType matrix_type = mlir::RankedTensorType::get(
      {original_layout.mesh().rank(), original_layout.rank()},
      builder.getIntegerType(32));
  mlir::Attribute matrix_attr =
      mlir::DenseIntElementsAttr::get(matrix_type, matrix);
  mlir::Value matrix_value =
      builder.create<mlir::TF::ConstOp>(all_scatter.getLoc(), matrix_attr)
          .getResult();

  // Compute the offset from matrix_value and mesh_coordinates.
  mlir::TF::MatMulOp offset = builder.create<mlir::TF::MatMulOp>(
      all_scatter.getLoc(),
      mlir::RankedTensorType::get({1, original_layout.rank()},
                                  builder.getIntegerType(32)),
      mesh_coordinates, matrix_value);

  // Input to slice needs to be rank 1, so we need to squeeze it.
  mlir::TF::SqueezeOp offset_squeezed = builder.create<mlir::TF::SqueezeOp>(
      all_scatter.getLoc(),
      mlir::RankedTensorType::get({original_layout.rank()},
                                  builder.getIntegerType(32)),
      offset.product(), builder.getI64ArrayAttr({0}));

  auto result = builder.create<mlir::TF::SliceOp>(
      all_scatter.getLoc(), output_type, all_scatter.input(),
      offset_squeezed.output(), slice_shape_value);

  SetSingleLayoutOnOp(result, desired_layout);

  all_scatter.output().replaceAllUsesExcept(result.output(), result);
  all_scatter.erase();

  return mlir::LogicalResult::success();
}

struct DTensorAllReduceLowering
    : public DTensorAllReduceLoweringBase<DTensorAllReduceLowering> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::ModuleOp module = getOperation();

    // Find all DTensorAllReduce ops.
    llvm::SmallVector<mlir::TF::DTensorAllReduceOp, 4> all_reduces;
    module.walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
      all_reduces.emplace_back(all_reduce);
    });

    // Replace every DTensorAllReduce op with device-specific implementations.
    for (auto& all_reduce : all_reduces)
      if (mlir::failed(LowerAllReduceOp(context, all_reduce)))
        return signalPassFailure();
  }
};

struct DTensorReduceScatterLowering
    : public DTensorReduceScatterLoweringBase<DTensorReduceScatterLowering> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // Find all DTensorReduceScatter ops.
    llvm::SmallVector<mlir::TF::DTensorReduceScatterOp, 4> all_reduces;
    module.walk([&](mlir::TF::DTensorReduceScatterOp all_reduce) {
      all_reduces.emplace_back(all_reduce);
    });

    // Replace every DTensorReduceScatter op with device-specific
    // implementations.
    for (auto& all_reduce : all_reduces)
      if (mlir::failed(LowerReduceScatterOp(all_reduce)))
        return signalPassFailure();
  }
};

struct DTensorAllGatherLowering
    : public DTensorAllGatherLoweringBase<DTensorAllGatherLowering> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // Process all DTensorAllGather ops.
    llvm::SmallVector<mlir::TF::DTensorAllGatherOp, 4> all_gathers;
    module.walk([&](mlir::TF::DTensorAllGatherOp all_gather) {
      all_gathers.emplace_back(all_gather);
    });

    for (mlir::TF::DTensorAllGatherOp all_gather : all_gathers)
      if (mlir::failed(LowerAllGatherOp(all_gather)))
        return signalPassFailure();
  }
};

struct DTensorAllScatterLowering
    : public DTensorAllScatterLoweringBase<DTensorAllScatterLowering> {
  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // Process all DTensorAllScatter ops.
    llvm::SmallVector<mlir::TF::DTensorAllScatterOp, 4> all_scatters;
    module.walk([&](mlir::TF::DTensorAllScatterOp all_scatter) {
      all_scatters.emplace_back(all_scatter);
    });

    for (mlir::TF::DTensorAllScatterOp all_scatter : all_scatters)
      if (mlir::failed(LowerAllScatterOp(all_scatter)))
        return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorAllReduceLoweringPass() {
  return std::make_unique<DTensorAllReduceLowering>();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorReduceScatterLoweringPass() {
  return std::make_unique<DTensorReduceScatterLowering>();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorAllGatherLoweringPass() {
  return std::make_unique<DTensorAllGatherLowering>();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorAllScatterLoweringPass() {
  return std::make_unique<DTensorAllScatterLowering>();
}

}  // namespace dtensor
}  // namespace tensorflow