/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/collectives.h"

#include <cstdint>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives_common.h"
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/sparse_expander_common.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {

namespace ops_util = ::mlir::TF::collection_ops_util;

}  // namespace
StatusOr<mlir::Value> EmitAllGather(
    mlir::OpBuilder& builder, mlir::Value input,
    const dtensor::Layout& src_layout, const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  if (src_layout.IsEquivalent(tgt_layout)) return input;

  if (src_layout.rank() != tgt_layout.rank()) {
    return errors::InvalidArgument(
        "Expected source and target layout to have the same rank, got ",
        src_layout.rank(), " vs ", tgt_layout.rank());
  }

  // Check that tgt_layout is less sharded than src_layout.
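  // For example, gathering from ["x", "unsharded"] to
  // ["unsharded", "unsharded"] is allowed, while the reverse direction is not
  // (that case is handled by EmitAllScatter).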
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (src_layout.sharding_spec(i) != tgt_layout.sharding_spec(i) &&
        Layout::IsShardedDimension(tgt_layout.sharding_spec(i))) {
      return errors::InvalidArgument("source layout (", src_layout.ToString(),
                                     ") for all gather is not less sharded "
                                     "than the target layout (",
                                     tgt_layout.ToString(), ")");
    }
  }

  // For convenience, operate on explicit input shapes. This isn't necessary,
  // as we could instead generate operations on top of the dynamic shape.
  const mlir::TensorType input_type =
      input.getType().dyn_cast<mlir::TensorType>();
  if (!input_type) {
    return errors::Internal(
        llvm::formatv(
            "Cannot cast input_type : {0} to TensorType. Shape must be "
            "statically known before emitting AllGather. This should not "
            "happen as we already cast it when getting its shape.",
            input.getType())
            .str());
  }

  TF_ASSIGN_OR_RETURN(mlir::TensorType global_type,
                      GlobalTypeFromLocalType(src_layout, input_type));
  TF_ASSIGN_OR_RETURN(mlir::TensorType output_type,
                      LocalTypeFromGlobalType(tgt_layout, global_type));

  mlir::Location loc = DT_LOC2(input.getLoc(), "DTensorAllGatherOp");
  mlir::TF::DTensorAllGatherOp all_gather =
      builder.create<mlir::TF::DTensorAllGatherOp>(
          loc, output_type, input,
          mlir::dtensor::LayoutAttr::get(builder.getContext(), src_layout),
          mlir::dtensor::LayoutAttr::get(builder.getContext(), tgt_layout));
  SetSingleLayoutOnOp(all_gather, tgt_layout);

  if (newly_created_ops != nullptr) newly_created_ops->insert(all_gather);

  return all_gather.output();
}

StatusOr<const mlir::Value> EmitAllScatter(
    mlir::OpBuilder& builder, const mlir::Value& original_value,
    const Layout& original_layout, const Layout& desired_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  if (original_layout.IsEquivalent(desired_layout)) return original_value;

  // Return an error early if the desired layout is not more sharded than the
  // original_layout.
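  // For example, scattering from ["unsharded", "y"] to ["x", "y"] is allowed,
  // while scattering from ["x", "y"] to ["unsharded", "y"] is not (that case
  // is handled by EmitAllGather).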
  assert(original_layout.rank() == desired_layout.rank());
  for (int i = 0; i < original_layout.rank(); ++i) {
    if (original_layout.sharding_spec(i) != desired_layout.sharding_spec(i) &&
        Layout::IsShardedDimension(original_layout.sharding_spec(i))) {
      return errors::InvalidArgument(
          "EmitAllScatter was passed a desired_layout ",
          desired_layout.ToString(),
          " which was not more sharded than the original_layout ",
          original_layout.ToString());
    }
  }

  const mlir::TensorType input_type =
      original_value.getType().dyn_cast<mlir::TensorType>();
  if (!input_type)
    return errors::InvalidArgument(
        "input to EmitAllScatter does not have a TensorType");

  TF_ASSIGN_OR_RETURN(const mlir::TensorType global_type,
                      GlobalTypeFromLocalType(original_layout, input_type));
  TF_ASSIGN_OR_RETURN(const mlir::TensorType output_type,
                      LocalTypeFromGlobalType(desired_layout, global_type));

  mlir::Location loc = DT_LOC2(original_value.getLoc(), "DTensorAllScatterOp");
  mlir::TF::DTensorAllScatterOp all_scatter =
      builder.create<mlir::TF::DTensorAllScatterOp>(
          loc, output_type, original_value,
          mlir::dtensor::LayoutAttr::get(builder.getContext(), original_layout),
          mlir::dtensor::LayoutAttr::get(builder.getContext(), desired_layout));
  SetSingleLayoutOnOp(all_scatter, desired_layout);

  if (newly_created_ops != nullptr) newly_created_ops->insert(all_scatter);

  return all_scatter.output();
}

StatusOr<mlir::Value> EmitDenseToSparseToDense(
    mlir::OpBuilder& builder, mlir::Value input,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  // First create a Dense To Sparse Op. Since there is no DenseToSparseOp,
  // we do it manually by creating the indices, values, and shapes tensor
  // through various ops.
  //
  // indices tensor = tf.where(tf.not_equal(input, tf.zeros_like(tensor)))
  // values tensor = tf.gather_nd(input, indices)
  // shape tensor = tf.shape(input)
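  //
  // For example, for input [[0, 7], [5, 0]] this produces indices
  // [[0, 1], [1, 0]], values [7, 5], and shape [2, 2].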
  mlir::TF::ZerosLikeOp zeros_like =
      builder.create<mlir::TF::ZerosLikeOp>(input.getLoc(), input);
  mlir::TF::NotEqualOp not_equal = builder.create<mlir::TF::NotEqualOp>(
      zeros_like.getLoc(), input, zeros_like, builder.getBoolAttr(false));

  mlir::TF::WhereOp indices = builder.create<mlir::TF::WhereOp>(
      not_equal.getLoc(),
      mlir::RankedTensorType::get(GetShapeOfValue(not_equal).ValueOrDie(),
                                  builder.getI64Type()),
      not_equal);

  mlir::TF::GatherNdOp values = builder.create<mlir::TF::GatherNdOp>(
      input.getLoc(), input.getType(), input, indices);
  auto shape = builder.create<mlir::TF::ShapeOp>(input.getLoc(), input,
                                                 builder.getBoolAttr(false));

  // Emit a SparseToDenseOp and replace the SparseTensor with the result of
  // this new op.
  auto zero_scalar = CreateZeroScalarConst(
      builder, input.getLoc(),
      input.getType().cast<mlir::TensorType>().getElementType());
  if (!zero_scalar.has_value())
    return errors::Internal("Failure in creating a zero scalar const");

  auto dense = builder.create<mlir::TF::SparseToDenseOp>(
      input.getLoc(), input.getType(),
      mlir::ValueRange({indices, shape, values, zero_scalar.value()}));

  if (newly_created_ops != nullptr) {
    for (auto new_op : {dense.getOperation(), shape.getOperation(),
                        values.getOperation(), indices.getOperation(),
                        not_equal.getOperation(), zeros_like.getOperation()}) {
      newly_created_ops->insert(new_op);
    }
  }

  return dense.getResult();
}

StatusOr<mlir::Value> EmitRelayout(
    mlir::Value input, const dtensor::Layout& src_layout,
    const dtensor::Layout& tgt_layout,
    llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
  // EmitRelayout is performed by doing a split, an AllGather, and another
  // split. The first split opportunistically splits input tensor dimension i
  // on mesh axis x if:
  // 1. tgt_layout contains x at position i,
  // 2. src_layout is unsharded at position i, and
  // 3. src_layout does not contain mesh axis x.
  // This produces intermediate layout 1.
  // Next an AllGather is performed on any axis in intermediate layout 1
  // that does not agree with the sharding on the output axis.
  // This produces intermediate layout 2.
  // A split is performed from intermediate layout 2 to the tgt layout.
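  //
  // For example, relaying out from ["x", "unsharded"] to ["unsharded", "x"]:
  // the first split is a no-op (dimension 1 cannot be split on "x" because
  // "x" is already used by the source), so intermediate layout 1 is
  // ["x", "unsharded"]; the AllGather unshards dimension 0, giving
  // intermediate layout 2 = ["unsharded", "unsharded"]; and the final split
  // shards dimension 1 to produce ["unsharded", "x"].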

  if (src_layout.IsEquivalent(tgt_layout)) return input;

  // Save whether the input is from a SparseToDenseOp. If it is, then we will
  // emit a DenseToSparse and a SparseToDense op.
  bool is_sparse = IsSparseValue(input);
  if (!input.getType().isa<mlir::RankedTensorType>())
    return errors::Internal(
        "attempting to relayout a tensor that does not "
        "have a rank");

  if (src_layout.mesh() != tgt_layout.mesh()) {
    return errors::Internal("Attempted to relayout to a different mesh.");
  }
  if (src_layout.rank() != tgt_layout.rank()) {
    return errors::Internal(
        "Attempted to relayout to a different global shape.");
  }

  absl::flat_hash_set<std::string> src_sharding_dims;
  for (int i = 0; i < src_layout.rank(); ++i)
    src_sharding_dims.emplace(src_layout.sharding_spec(i));

  std::vector<ShardingSpec> intermediate_specs_1(src_layout.rank());
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (Layout::IsShardedSpec(tgt_layout.dim(i)) &&
        !Layout::IsShardedSpec(src_layout.dim(i)) &&
        !src_sharding_dims.contains(tgt_layout.sharding_spec(i)))
      intermediate_specs_1[i] = tgt_layout.dim(i);
    else
      intermediate_specs_1[i] = src_layout.dim(i);
  }
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_layout_1,
      Layout::GetLayout(intermediate_specs_1, src_layout.mesh()));

  mlir::OpBuilder builder(input.getContext());
  TF_RETURN_IF_ERROR(SetBuilderInsertionAfterValue(input, builder));

  llvm::SmallPtrSet<mlir::Operation*, 4> local_newly_created_ops;
  TF_ASSIGN_OR_RETURN(mlir::Value split_result,
                      EmitAllScatter(builder, input, src_layout,
                                     intermediate_layout_1, newly_created_ops));

  std::vector<ShardingSpec> intermediate_specs_2(src_layout.rank());
  for (int i = 0; i < src_layout.rank(); ++i) {
    if (Layout::IsShardedSpec(intermediate_specs_1[i]) &&
        intermediate_specs_1[i].sharding_spec() != tgt_layout.sharding_spec(i))
      intermediate_specs_2[i].set_sharding_spec(Layout::kUnshardedDim);
    else
      intermediate_specs_2[i] = intermediate_specs_1[i];
  }
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_layout_2,
      Layout::GetLayout(intermediate_specs_2, src_layout.mesh()));

  TF_ASSIGN_OR_RETURN(
      mlir::Value concat_result,
      EmitAllGather(builder, split_result, intermediate_layout_1,
                    intermediate_layout_2, newly_created_ops));

  auto all_scatter =
      EmitAllScatter(builder, concat_result, intermediate_layout_2, tgt_layout,
                     newly_created_ops);

  if (!is_sparse) return all_scatter;
  if (!all_scatter.ok()) return all_scatter;
  return EmitDenseToSparseToDense(builder, all_scatter.ValueOrDie(),
                                  newly_created_ops);
}

StatusOr<mlir::Operation*> EmitAllReduce(
    mlir::OpBuilder& builder, const dtensor::Layout& output_layout,
    const absl::flat_hash_set<std::string>& reduced_dims,
    mlir::Operation* input, absl::string_view reduce_op) {
  TF_ASSIGN_OR_RETURN(auto partitions, GetAllReducePartitionsFromReducedDims(
                                           output_layout, reduced_dims));
  const int32 num_partitions = partitions.size();

  // If every device lives in its own partition, we don't need to emit a
  // collective.
  if (num_partitions == output_layout.num_devices()) {
    return InferSPMDExpandedLocalShape(input);
  }

  // Construct a flattened list of reduce partitions. This will be converted
  // into a 2-D const tensor for the DTensorAllReduce op.
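  // For example, on a 2x2 mesh ["x", "y"] with row-major device ids 0..3,
  // reducing over "y" groups devices {0, 1} and {2, 3}, i.e. a group
  // assignment of [[0, 1], [2, 3]].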
  std::vector<int32> partitions_flat;
  for (auto& p : partitions) {
    if (p.second.size() != partitions.begin()->second.size()) {
      return errors::InvalidArgument(
          "AllReduce partitions had different sizes -- this is not supported "
          "in MLIR.");
    }
    partitions_flat.insert(partitions_flat.end(), p.second.begin(),
                           p.second.end());
  }

  int32 partition_size = partitions.begin()->second.size();
  auto shaped_type = mlir::RankedTensorType::get(
      {num_partitions, partition_size},
      mlir::IntegerType::get(builder.getContext(), 32));
  auto group_assignment =
      mlir::DenseIntElementsAttr::get(shaped_type, partitions_flat);

  TF_ASSIGN_OR_RETURN(std::string device_type,
                      DeviceTypeFromMesh(output_layout.mesh()));

  mlir::Location loc = DT_LOC2(input->getLoc(), "DTensorAllReduceOp");
  auto all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      loc, input->getResultTypes()[0], input->getOpResult(0),
      builder.create<mlir::TF::ConstOp>(loc, group_assignment),
      builder.getStringAttr(std::string(reduce_op)),
      builder.getStringAttr(device_type));
  SetSingleLayoutOnOp(all_reduce, output_layout);
  input->getOpResult(0).replaceAllUsesExcept(
      all_reduce.getResult(),
      llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce});
  return all_reduce.getOperation();
}

namespace {

// Returns an offset multiplier used to calculate a device id / mesh
// coordinate.
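// For example, for a mesh of shape [2 ("x"), 4 ("y")] with row-major device
// ids, the offset for "x" is 4 and for "y" is 1, so
// device_id = x_coordinate * 4 + y_coordinate * 1.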
int GetMeshDimensionOffsetWithNeighbor(const Mesh& mesh,
                                       const std::string& mesh_dim) {
  const int index = mesh.GetMeshDimIndexWithName(mesh_dim);
  const std::vector<int64_t> mesh_dim_sizes = mesh.dim_sizes();
  int offset = 1;
  for (int i = index + 1; i < mesh_dim_sizes.size(); ++i) {
    offset = offset * mesh_dim_sizes[i];
  }
  return offset;
}

// Returns the mesh coordinate along mesh dimension `mesh_dim_name` for the
// given `device_id`.
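// For example, on a mesh of shape [2 ("x"), 4 ("y")], device id 6 has "x"
// coordinate (6 / 4) % 2 = 1 and "y" coordinate (6 / 1) % 4 = 2.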
StatusOr<int> GetMeshCoordinateIndex(const Mesh& mesh,
                                     const std::string& mesh_dim_name,
                                     int device_id) {
  const int offset = GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim_name);
  TF_ASSIGN_OR_RETURN(int64_t mesh_dim_size, mesh.dim_size(mesh_dim_name));

  return (device_id / offset) % mesh_dim_size;
}

// Returns a 2D tensor of shape [N, 2] that specifies the source-target device
// pairs used for halo exchange.
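// For example, for a 1D mesh ["x"] of size 4 with all four devices local and
// shift_left=false, each device sends to its right neighbor with wrap-around,
// producing the pairs [[0, 1], [1, 2], [2, 3], [3, 0]].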
StatusOr<mlir::Value> CreateConstSrcTargetPair(const Mesh& mesh,
                                               const std::string& mesh_dim_name,
                                               bool shift_left,
                                               mlir::Location location,
                                               mlir::OpBuilder& builder) {
  const int mesh_dim_index = mesh.GetMeshDimIndexWithName(mesh_dim_name);
  const std::vector<MeshDimension> mesh_dimensions = mesh.dims();

  llvm::SmallVector<int, 4> src_target_pair_flat;
  src_target_pair_flat.reserve(mesh.local_device_ids().size() * 2);
  for (const int local_device_id : mesh.local_device_ids()) {
    // Calculate the mesh coordinate of the current local device id.
    llvm::SmallVector<int, 4> mesh_coordinate_for_device_id;

    for (const MeshDimension& mesh_dim : mesh_dimensions) {
      TF_ASSIGN_OR_RETURN(
          const int coordinate,
          GetMeshCoordinateIndex(mesh, mesh_dim.name, local_device_id));

      mesh_coordinate_for_device_id.push_back(coordinate);
    }

    // If the mesh coordinate is on the left/right edge, then we conduct halo
    // exchange with the processor that executes the `wrapped around` input
    // block.
    const int mesh_coordinate = mesh_coordinate_for_device_id[mesh_dim_index];
    TF_ASSIGN_OR_RETURN(const int dim_size, mesh.dim_size(mesh_dim_name));

    // For tensors requiring halo exchange, we use collective permute.
    const int src_device_id = local_device_id;
    int target_device_id = 0;
    for (const auto& data : llvm::enumerate(mesh_dimensions)) {
      const MeshDimension& mesh_dim = data.value();
      const int index = data.index();

      int target_mesh_coordinate = 1;
      if (mesh_dim.name == mesh_dim_name) {
        target_mesh_coordinate =
            shift_left ? mesh_coordinate - 1 : mesh_coordinate + 1;

        // For processors executing input tensors on the left/right edges, the
        // target processor is the one that executes the wrapped-around input
        // block.
        if (target_mesh_coordinate < 0 || target_mesh_coordinate >= dim_size)
          target_mesh_coordinate =
              (target_mesh_coordinate + dim_size) % dim_size;

      } else {
        target_mesh_coordinate = mesh_coordinate_for_device_id[index];
      }

      target_device_id +=
          target_mesh_coordinate *
          GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim.name);
    }
    src_target_pair_flat.push_back(src_device_id);
    src_target_pair_flat.push_back(target_device_id);
  }

  const int num_pairs = src_target_pair_flat.size() / 2;
  auto shaped_type = mlir::RankedTensorType::get(
      {num_pairs, 2}, mlir::IntegerType::get(builder.getContext(), 32));

  auto src_target_attr =
      mlir::DenseIntElementsAttr::get(shaped_type, src_target_pair_flat);
  mlir::Value src_target_pair_tensor =
      builder.create<mlir::TF::ConstOp>(location, src_target_attr);
  return src_target_pair_tensor;
}

}  // namespace

StatusOr<mlir::Value> EmitHaloExchange(mlir::OpBuilder& builder, int halo_size,
                                       const std::string& mesh_dim,
                                       const Layout& layout,
                                       mlir::Value mesh_coordinates,
                                       mlir::tf_device::ClusterOp cluster,
                                       mlir::Location location,
                                       mlir::Value tensor) {
  const Mesh& mesh = layout.mesh();

  // Check mesh dimension requirements for halo exchange.
  if (!mesh.IsMeshDim(mesh_dim))
    return errors::InvalidArgument(
        "Requested halo exchange on unknown mesh dim");

  // TODO(hongjunchoi): Add support for halo exchange for GPU/CPU.
  if (!mesh.is_tpu_mesh())
    return errors::InvalidArgument("Halo exchange is only supported on TPU.");

  auto input_tensor_type = tensor.getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_tensor_type || !input_tensor_type.hasStaticShape())
    return errors::InvalidArgument(
        "Static shape of input tensor must be known for halo exchange.");

  llvm::ArrayRef<int64_t> input_tensor_shape = input_tensor_type.getShape();
  const std::vector<std::string> sharding_specs = layout.sharding_spec_strs();
  const int split_dim_index = std::distance(
      sharding_specs.begin(), llvm::find(sharding_specs, mesh_dim));

  if (input_tensor_shape[split_dim_index] < halo_size)
    return errors::InvalidArgument(
        "For halo exchange, the input shard tensor size on each processor "
        "must be at least the halo size");

  TF_ASSIGN_OR_RETURN(const int mesh_dim_index, mesh.idx_for_dim(mesh_dim));

  TF_ASSIGN_OR_RETURN(mlir::Value scalar_mesh_coordinate,
                      SelectScalarValueFromArray(builder, mesh_dim_index,
                                                 location, mesh_coordinates));

  llvm::SmallVector<int64_t, 4> halo_exchange_tensor_shape;
  for (const auto& size_and_index : llvm::enumerate(input_tensor_shape)) {
    const int index = size_and_index.index();
    const int size = size_and_index.value();
    halo_exchange_tensor_shape.push_back(index == split_dim_index ? halo_size
                                                                  : size);
  }

  // Find the halo tensor value to pad on the `left` side. Note that halo
  // exchange can happen on the top/bottom/left/right sides of a spatially
  // partitioned tensor. However, we use `left`/`right` here since the actual
  // direction is implicit in the mesh dimension.
  //
  // For example, if the mesh dimension splits the input tensor along its
  // height dimension, then `left` actually means the tensor to pad on the top
  // side.
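  //
  // For example, with halo_size = 1 and a [4, 8] local shard that is split
  // across two processors along the first dimension, each processor pads one
  // row on each side and produces a [6, 8] result.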
  mlir::Value is_on_left_edge = builder.create<mlir::TF::EqualOp>(
      location, CreateIntScalarConst(0, builder, location, /*use_int64=*/false),
      scalar_mesh_coordinate, builder.getBoolAttr(true));

  TF_ASSIGN_OR_RETURN(const int mesh_dim_size, mesh.dim_size(mesh_dim));
  mlir::Value is_on_right_edge = builder.create<mlir::TF::EqualOp>(
      location,
      CreateIntScalarConst(mesh_dim_size - 1, builder, location,
                           /*use_int64=*/false),
      scalar_mesh_coordinate, builder.getBoolAttr(true));

  // Create zero ghost tensor to pad on left side.
  mlir::RankedTensorType halo_tensor_type = mlir::RankedTensorType::get(
      halo_exchange_tensor_shape, input_tensor_type.getElementType());
  auto halo_type = mlir::RankedTensorType::get(
      halo_tensor_type.getShape(), input_tensor_type.getElementType());

  mlir::Attribute const_attr;
  if (halo_type.getElementType().isIntOrIndex()) {
    const_attr =
        mlir::DenseIntElementsAttr::get(halo_type, llvm::SmallVector<int>{0});
  } else {
    const_attr =
        mlir::DenseFPElementsAttr::get(halo_type, llvm::SmallVector<float>{0});
  }

  mlir::Value ghost_tensor_left =
      builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();

  // Get the right side slice of the input tensor to pad on left side.
  llvm::SmallVector<int64_t, 4> begin_left(layout.rank(), 0);
  begin_left[split_dim_index] = input_tensor_shape[split_dim_index] - halo_size;
  mlir::Value begin_tensor_left =
      ops_util::GetR1Const(begin_left, builder, location);

  llvm::SmallVector<int64_t, 4> size(input_tensor_shape.begin(),
                                     input_tensor_shape.end());
  size[split_dim_index] = halo_size;

  mlir::Value size_tensor_left = ops_util::GetR1Const(size, builder, location);
  mlir::Value sliced_tensor_left = builder.create<mlir::TF::SliceOp>(
      location, halo_type, tensor, begin_tensor_left, size_tensor_left);

  mlir::Value halo_tensor_left = builder.create<mlir::TF::SelectV2Op>(
      location, is_on_right_edge, ghost_tensor_left, sliced_tensor_left);

  // Invoke collective permute to receive the tensor from the neighboring
  // processor. Halo slices from the left neighbor are received on each
  // processor (they are shifted right).
  TF_ASSIGN_OR_RETURN(
      mlir::Value src_target_pair_left,
      CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/false, location,
                               builder));

  mlir::Value left_concat_value = builder.create<mlir::TF::CollectivePermuteOp>(
      location, sliced_tensor_left.getType(), halo_tensor_left,
      src_target_pair_left);

  mlir::Value ghost_tensor_right =
      builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();

  // Otherwise, the values to pad come from a different processor. We use
  // collective permute to access the tensor slice from another device.
  // Get the left side slice of the input tensor.
  llvm::SmallVector<int64_t, 4> begin_right(layout.rank(), 0);
  mlir::Value begin_tensor_right =
      ops_util::GetR1Const(begin_right, builder, location);
  mlir::Value size_tensor_right = ops_util::GetR1Const(size, builder, location);
  mlir::Value sliced_tensor_right = builder.create<mlir::TF::SliceOp>(
      location, halo_type, tensor, begin_tensor_right, size_tensor_right);

  // Find the halo tensor value to pad on the `right` side.
  // If the input block is on the right edge, we use the zero ghost tensor
  // instead.
  mlir::Value halo_tensor_right = builder.create<mlir::TF::SelectV2Op>(
      location, is_on_left_edge, ghost_tensor_right, sliced_tensor_right);

  // Invoke collective permute to receive the tensor from the neighboring
  // processor. Halo slices from the right neighbor are received on each
  // processor (they are shifted left).
  TF_ASSIGN_OR_RETURN(
      mlir::Value src_target_pair_right,
      CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/true, location,
                               builder));
  mlir::Value right_concat_value =
      builder.create<mlir::TF::CollectivePermuteOp>(
          location, sliced_tensor_right.getType(), halo_tensor_right,
          src_target_pair_right);

  // The final halo-exchanged value is the concatenation of left_concat_value,
  // tensor, and right_concat_value along the split dimension.
  llvm::SmallVector<int64_t, 4> final_shape(input_tensor_shape.begin(),
                                            input_tensor_shape.end());
  final_shape[split_dim_index] = final_shape[split_dim_index] + 2 * halo_size;

  auto final_type = mlir::RankedTensorType::get(
      final_shape, input_tensor_type.getElementType());
  mlir::Value concat_axis =
      CreateIntScalarConst(split_dim_index, builder, location);
  mlir::Value final_value = builder.create<mlir::TF::ConcatV2Op>(
      location, final_type,
      llvm::SmallVector<mlir::Value, 4>{left_concat_value, tensor,
                                        right_concat_value},
      concat_axis);

  return final_value;
}

}  // namespace dtensor
}  // namespace tensorflow