xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/collectives.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/collectives.h"
17 
18 #include <cstdint>
19 #include <string>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/strings/string_view.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/dtensor/cc/dstatus.h"
31 #include "tensorflow/dtensor/cc/tensor_layout.h"
32 #include "tensorflow/dtensor/mlir/collectives_common.h"
33 #include "tensorflow/dtensor/mlir/dtensor_location.h"
34 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
35 #include "tensorflow/dtensor/mlir/layout_parsing.h"
36 #include "tensorflow/dtensor/mlir/shape_utils.h"
37 #include "tensorflow/dtensor/mlir/sparse_expander_common.h"
38 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
39 #include "tensorflow/dtensor/mlir/value_utils.h"
40 
41 namespace tensorflow {
42 namespace dtensor {
43 
44 namespace {
45 
46 namespace ops_util = ::mlir::TF::collection_ops_util;
47 
48 }  // namespace
49 
50 StatusOr<mlir::Value> EmitAllGather(
51     mlir::OpBuilder& builder, mlir::Value input,
52     const dtensor::Layout& src_layout, const dtensor::Layout& tgt_layout,
53     llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
54   if (src_layout.IsEquivalent(tgt_layout)) return input;
55 
56   if (src_layout.rank() != tgt_layout.rank()) {
57     return errors::InvalidArgument(
58         "Expected source and target layout to have the same rank, got ",
59         src_layout.rank(), " vs ", tgt_layout.rank());
60   }
61 
62   // Check that the tgt_layout is less sharded than src_layout.
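  // For example, going from a hypothetical layout [x, unsharded] to
  // [unsharded, unsharded] is allowed, but going to [unsharded, y] is not,
  // since dimension 1 would become more sharded.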
63   for (int i = 0; i < src_layout.rank(); ++i) {
64     if (src_layout.sharding_spec(i) != tgt_layout.sharding_spec(i) &&
65         Layout::IsShardedDimension(tgt_layout.sharding_spec(i))) {
66       return errors::InvalidArgument("target layout (", tgt_layout.ToString(),
67                                      ") for all gather is not less sharded "
68                                      "than the source layout (",
69                                      src_layout.ToString(), ")");
70     }
71   }
72 
73   // For convenience, operate on explicit input shapes. This isn't necessary,
74   // as we could instead generate operations on top of the dynamic shape.
75   const mlir::TensorType input_type =
76       input.getType().dyn_cast<mlir::TensorType>();
77   if (!input_type) {
78     return errors::Internal(
79         llvm::formatv(
80             "Cannot cast input_type: {0} to TensorType. Shape must be "
81             "statically known before emitting AllGather. This should not "
82             "happen as we already cast it when getting its shape.",
83             input.getType())
84             .str());
85   }
86 
87   TF_ASSIGN_OR_RETURN(mlir::TensorType global_type,
88                       GlobalTypeFromLocalType(src_layout, input_type));
89   TF_ASSIGN_OR_RETURN(mlir::TensorType output_type,
90                       LocalTypeFromGlobalType(tgt_layout, global_type));
91 
92   mlir::Location loc = DT_LOC2(input.getLoc(), "DTensorAllGatherOp");
93   mlir::TF::DTensorAllGatherOp all_gather =
94       builder.create<mlir::TF::DTensorAllGatherOp>(
95           loc, output_type, input,
96           mlir::dtensor::LayoutAttr::get(builder.getContext(), src_layout),
97           mlir::dtensor::LayoutAttr::get(builder.getContext(), tgt_layout));
98   SetSingleLayoutOnOp(all_gather, tgt_layout);
99 
100   if (newly_created_ops != nullptr) newly_created_ops->insert(all_gather);
101 
102   return all_gather.output();
103 }
104 
105 StatusOr<const mlir::Value> EmitAllScatter(
106     mlir::OpBuilder& builder, const mlir::Value& original_value,
107     const Layout& original_layout, const Layout& desired_layout,
108     llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
109   if (original_layout.IsEquivalent(desired_layout)) return original_value;
110 
111   // Return an error early if the desired layout is not more sharded than the
112   // original_layout.
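  // For example, scattering from a hypothetical layout [unsharded, unsharded]
  // to [x, unsharded] is fine, but going from [x, unsharded] to [unsharded, y]
  // is rejected, since dimension 0 would first need an AllGather.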
113   assert(original_layout.rank() == desired_layout.rank());
114   for (int i = 0; i < original_layout.rank(); ++i) {
115     if (original_layout.sharding_spec(i) != desired_layout.sharding_spec(i) &&
116         Layout::IsShardedDimension(original_layout.sharding_spec(i))) {
117       return errors::InvalidArgument(
118           "EmitAllScatter was passed a desired_layout ",
119           desired_layout.ToString(),
120           " which was not more sharded than the original_layout ",
121           original_layout.ToString());
122     }
123   }
124 
125   const mlir::TensorType input_type =
126       original_value.getType().dyn_cast<mlir::TensorType>();
127   if (!input_type)
128     return errors::InvalidArgument(
129         "input to EmitAllScatter does not have a TensorType");
130 
131   TF_ASSIGN_OR_RETURN(const mlir::TensorType global_type,
132                       GlobalTypeFromLocalType(original_layout, input_type));
133   TF_ASSIGN_OR_RETURN(const mlir::TensorType output_type,
134                       LocalTypeFromGlobalType(desired_layout, global_type));
135 
136   mlir::Location loc = DT_LOC2(original_value.getLoc(), "DTensorAllScatterOp");
137   mlir::TF::DTensorAllScatterOp all_scatter =
138       builder.create<mlir::TF::DTensorAllScatterOp>(
139           loc, output_type, original_value,
140           mlir::dtensor::LayoutAttr::get(builder.getContext(), original_layout),
141           mlir::dtensor::LayoutAttr::get(builder.getContext(), desired_layout));
142   SetSingleLayoutOnOp(all_scatter, desired_layout);
143 
144   if (newly_created_ops != nullptr) newly_created_ops->insert(all_scatter);
145 
146   return all_scatter.output();
147 }
148 
149 StatusOr<mlir::Value> EmitDenseToSparseToDense(
150     mlir::OpBuilder& builder, mlir::Value input,
151     llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
152   // First create a dense-to-sparse conversion. Since there is no
153   // DenseToSparseOp, we do it manually by creating the indices, values, and
154   // shape tensors through various ops.
155   //
156   // indices tensor = tf.where(tf.not_equal(input, tf.zeros_like(tensor)))
157   // values tensor = tf.gather_nd(input, indices)
158   // shape tensor = tf.shape(input)
159   mlir::TF::ZerosLikeOp zeros_like =
160       builder.create<mlir::TF::ZerosLikeOp>(input.getLoc(), input);
161   mlir::TF::NotEqualOp not_equal = builder.create<mlir::TF::NotEqualOp>(
162       zeros_like.getLoc(), input, zeros_like, builder.getBoolAttr(false));
163 
164   mlir::TF::WhereOp indices = builder.create<mlir::TF::WhereOp>(
165       not_equal.getLoc(),
166       mlir::RankedTensorType::get(GetShapeOfValue(not_equal).ValueOrDie(),
167                                   builder.getI64Type()),
168       not_equal);
169 
170   mlir::TF::GatherNdOp values = builder.create<mlir::TF::GatherNdOp>(
171       input.getLoc(), input.getType(), input, indices);
172   auto shape = builder.create<mlir::TF::ShapeOp>(input.getLoc(), input,
173                                                  builder.getBoolAttr(false));
174 
175   // Emit a SparseToDenseOp and replace the SparseTensor with the result of
176   // this new op.
177   auto zero_scalar = CreateZeroScalarConst(
178       builder, input.getLoc(),
179       input.getType().cast<mlir::TensorType>().getElementType());
180   if (!zero_scalar.has_value())
181     return errors::Internal("Failure in creating a zero scalar const");
182 
183   auto dense = builder.create<mlir::TF::SparseToDenseOp>(
184       input.getLoc(), input.getType(),
185       mlir::ValueRange({indices, shape, values, zero_scalar.value()}));
186 
187   if (newly_created_ops != nullptr) {
188     for (auto new_op : {dense.getOperation(), shape.getOperation(),
189                         values.getOperation(), indices.getOperation(),
190                         not_equal.getOperation(), zeros_like.getOperation()}) {
191       newly_created_ops->insert(new_op);
192     }
193   }
194 
195   return dense.getResult();
196 }
197 
198 StatusOr<mlir::Value> EmitRelayout(
199     mlir::Value input, const dtensor::Layout& src_layout,
200     const dtensor::Layout& tgt_layout,
201     llvm::SmallPtrSet<mlir::Operation*, 4>* newly_created_ops) {
202   // EmitRelayout is performed by doing a split, an AllGather, and another split.
203   // The first split opportunistically splits input tensor dimension i on
204   // mesh axis x if:
205   // 1.  tgt_layout contains x at position i.
206   // 2.  src_layout is unsharded at position i.
207   // 3.  src_layout does not contain mesh axis x.
208   // This produces intermediate layout 1.
209   // Next, an all concat (AllGather) is performed on any axis in intermediate
210   // layout 1 that does not agree with the sharding on the output axis.
211   // This produces intermediate layout 2.
212   // Finally, a split is performed from intermediate layout 2 to the tgt layout.
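  // Illustrative walk-through (hypothetical mesh axes "x" and "y"):
  //   src_layout = [y, unsharded], tgt_layout = [unsharded, x]
  //   intermediate layout 1 = [y, x]          (dim 1 picks up x, which src does
  //                                            not use anywhere)
  //   intermediate layout 2 = [unsharded, x]  (dim 0 disagrees with tgt, so it
  //                                            is all-gathered)
  //   The final split from intermediate layout 2 to tgt_layout is a no-op here,
  //   since the two layouts already match.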
213 
214   if (src_layout.IsEquivalent(tgt_layout)) return input;
215 
216   // Save whether the input is from a SparseToDenseOp. If it is, then we will
217   // emit a DenseToSparse and a SparseToDense op.
218   bool is_sparse = IsSparseValue(input);
219   if (!input.getType().isa<mlir::RankedTensorType>())
220     return errors::Internal(
221         "attempting to relayout a tensor that does not "
222         "have a rank");
223 
224   if (src_layout.mesh() != tgt_layout.mesh()) {
225     return errors::Internal("Attempted to relayout to a different mesh.");
226   }
227   if (src_layout.rank() != tgt_layout.rank()) {
228     return errors::Internal(
229         "Attempted to relayout to a different global shape.");
230   }
231 
232   absl::flat_hash_set<std::string> src_sharding_dims;
233   for (int i = 0; i < src_layout.rank(); ++i)
234     src_sharding_dims.emplace(src_layout.sharding_spec(i));
235 
236   std::vector<ShardingSpec> intermediate_specs_1(src_layout.rank());
237   for (int i = 0; i < src_layout.rank(); ++i) {
238     if (Layout::IsShardedSpec(tgt_layout.dim(i)) &&
239         !Layout::IsShardedSpec(src_layout.dim(i)) &&
240         !src_sharding_dims.contains(tgt_layout.sharding_spec(i)))
241       intermediate_specs_1[i] = tgt_layout.dim(i);
242     else
243       intermediate_specs_1[i] = src_layout.dim(i);
244   }
245   TF_ASSIGN_OR_RETURN(
246       Layout intermediate_layout_1,
247       Layout::GetLayout(intermediate_specs_1, src_layout.mesh()));
248 
249   mlir::OpBuilder builder(input.getContext());
250   TF_RETURN_IF_ERROR(SetBuilderInsertionAfterValue(input, builder));
251 
252   llvm::SmallPtrSet<mlir::Operation*, 4> local_newly_created_ops;
253   TF_ASSIGN_OR_RETURN(mlir::Value split_result,
254                       EmitAllScatter(builder, input, src_layout,
255                                      intermediate_layout_1, newly_created_ops));
256 
257   std::vector<ShardingSpec> intermediate_specs_2(src_layout.rank());
258   for (int i = 0; i < src_layout.rank(); ++i) {
259     if (Layout::IsShardedSpec(intermediate_specs_1[i]) &&
260         intermediate_specs_1[i].sharding_spec() != tgt_layout.sharding_spec(i))
261       intermediate_specs_2[i].set_sharding_spec(Layout::kUnshardedDim);
262     else
263       intermediate_specs_2[i] = intermediate_specs_1[i];
264   }
265   TF_ASSIGN_OR_RETURN(
266       Layout intermediate_layout_2,
267       Layout::GetLayout(intermediate_specs_2, src_layout.mesh()));
268 
269   TF_ASSIGN_OR_RETURN(
270       mlir::Value concat_result,
271       EmitAllGather(builder, split_result, intermediate_layout_1,
272                     intermediate_layout_2, newly_created_ops));
273 
274   auto all_scatter =
275       EmitAllScatter(builder, concat_result, intermediate_layout_2, tgt_layout,
276                      newly_created_ops);
277 
278   if (!is_sparse) return all_scatter;
279   if (!all_scatter.ok()) return all_scatter;
280   return EmitDenseToSparseToDense(builder, all_scatter.ValueOrDie(),
281                                   newly_created_ops);
282 }
283 
284 StatusOr<mlir::Operation*> EmitAllReduce(
285     mlir::OpBuilder& builder, const dtensor::Layout& output_layout,
286     const absl::flat_hash_set<std::string>& reduced_dims,
287     mlir::Operation* input, absl::string_view reduce_op) {
288   TF_ASSIGN_OR_RETURN(auto partitions, GetAllReducePartitionsFromReducedDims(
289                                            output_layout, reduced_dims));
290   const int32 num_partitions = partitions.size();
291 
292   // If every device lives in its own partition, we don't need to emit a
293   // collective.
294   if (num_partitions == output_layout.num_devices()) {
295     return InferSPMDExpandedLocalShape(input);
296   }
297 
298   // Construct a flattened list of reduce partitions. This will be converted
299   // into a 2-D const tensor for the DTensorAllReduce op.
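  // For example, reducing over "y" on a hypothetical 2x2 mesh with dimensions
  // ["x", "y"] yields the partitions {{0, 1}, {2, 3}}, which become the 2x2
  // group assignment [[0, 1], [2, 3]].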
300   std::vector<int32> partitions_flat;
301   for (auto& p : partitions) {
302     if (p.second.size() != partitions.begin()->second.size()) {
303       return errors::InvalidArgument(
304           "AllReduce partitions had different sizes -- this is not supported "
305           "in MLIR.");
306     }
307     partitions_flat.insert(partitions_flat.end(), p.second.begin(),
308                            p.second.end());
309   }
310 
311   int32 partition_size = partitions.begin()->second.size();
312   auto shaped_type = mlir::RankedTensorType::get(
313       {num_partitions, partition_size},
314       mlir::IntegerType::get(builder.getContext(), 32));
315   auto group_assignment =
316       mlir::DenseIntElementsAttr::get(shaped_type, partitions_flat);
317 
318   TF_ASSIGN_OR_RETURN(std::string device_type,
319                       DeviceTypeFromMesh(output_layout.mesh()));
320 
321   mlir::Location loc = DT_LOC2(input->getLoc(), "DTensorAllReduceOp");
322   auto all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
323       loc, input->getResultTypes()[0], input->getOpResult(0),
324       builder.create<mlir::TF::ConstOp>(loc, group_assignment),
325       builder.getStringAttr(std::string(reduce_op)),
326       builder.getStringAttr(device_type));
327   SetSingleLayoutOnOp(all_reduce, output_layout);
328   input->getOpResult(0).replaceAllUsesExcept(
329       all_reduce.getResult(),
330       llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce});
331   return all_reduce.getOperation();
332 }
333 
334 namespace {
335 
336 // Returns an offset multiplier to calculate device id / mesh coordinate.
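// For example, on a hypothetical mesh of shape [2, 4] with dimensions
// ["x", "y"], the offset for "x" is 4 (the product of all trailing dimension
// sizes) and the offset for "y" is 1, so device_id = x_coord * 4 + y_coord.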
337 int GetMeshDimensionOffsetWithNeighbor(const Mesh& mesh,
338                                        const std::string& mesh_dim) {
339   const int index = mesh.GetMeshDimIndexWithName(mesh_dim);
340   const std::vector<int64_t> mesh_dim_sizes = mesh.dim_sizes();
341   int offset = 1;
342   for (int i = index + 1; i < mesh_dim_sizes.size(); ++i) {
343     offset = offset * mesh_dim_sizes[i];
344   }
345   return offset;
346 }
347 
348 // Returns the mesh coordinate along the mesh dimension `mesh_dim_name` for
349 // the given `device_id`.
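// Continuing the hypothetical [2, 4] ("x", "y") mesh above, device_id 6 has
// "x" coordinate (6 / 4) % 2 = 1 and "y" coordinate (6 / 1) % 4 = 2.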
350 StatusOr<int> GetMeshCoordinateIndex(const Mesh& mesh,
351                                      const std::string& mesh_dim_name,
352                                      int device_id) {
353   const int offset = GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim_name);
354   TF_ASSIGN_OR_RETURN(int64_t mesh_dim_size, mesh.dim_size(mesh_dim_name));
355 
356   return (device_id / offset) % mesh_dim_size;
357 }
358 
359 // Returns a 2D tensor array of size [N, 2] that specifies the source-target
360 // pairs to be used for halo exchange.
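// For example, on a hypothetical 1D mesh of size 4 over dimension "x"
// (assuming all four devices are local to this client), the pairs are
// {0->1, 1->2, 2->3, 3->0} when shift_left is false and
// {0->3, 1->0, 2->1, 3->2} when shift_left is true.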
361 StatusOr<mlir::Value> CreateConstSrcTargetPair(const Mesh& mesh,
362                                                const std::string& mesh_dim_name,
363                                                bool shift_left,
364                                                mlir::Location location,
365                                                mlir::OpBuilder& builder) {
366   const int mesh_dim_index = mesh.GetMeshDimIndexWithName(mesh_dim_name);
367   const std::vector<MeshDimension> mesh_dimensions = mesh.dims();
368 
369   llvm::SmallVector<int, 4> src_target_pair_flat;
370   src_target_pair_flat.reserve(mesh.local_device_ids().size() * 2);
371   for (const int local_device_id : mesh.local_device_ids()) {
372     // Calculate the mesh coordinate of the current local device id.
373     llvm::SmallVector<int, 4> mesh_coordinate_for_device_id;
374 
375     for (const MeshDimension& mesh_dim : mesh_dimensions) {
376       TF_ASSIGN_OR_RETURN(
377           const int coordinate,
378           GetMeshCoordinateIndex(mesh, mesh_dim.name, local_device_id));
379 
380       mesh_coordinate_for_device_id.push_back(coordinate);
381     }
382 
383     // If the mesh coordinate is on the left/right edge, then we conduct halo
384     // exchange with the processor that executes the `wrapped around` input
385     // block.
386     const int mesh_coordinate = mesh_coordinate_for_device_id[mesh_dim_index];
387     TF_ASSIGN_OR_RETURN(const int dim_size, mesh.dim_size(mesh_dim_name));
388 
389     // For tensors requiring halo exchange, we use collective permute.
390     const int src_device_id = local_device_id;
391     int target_device_id = 0;
392     for (const auto& data : llvm::enumerate(mesh_dimensions)) {
393       const MeshDimension& mesh_dim = data.value();
394       const int index = data.index();
395 
396       int target_mesh_coordinate = 1;
397       if (mesh_dim.name == mesh_dim_name) {
398         target_mesh_coordinate =
399             shift_left ? mesh_coordinate - 1 : mesh_coordinate + 1;
400 
401         // For processors executing the input tensor on the left/right edges, the
402         // target processor is the one that executes the wrapped-around input block.
403         if (target_mesh_coordinate < 0 || target_mesh_coordinate >= dim_size)
404           target_mesh_coordinate =
405               (target_mesh_coordinate + dim_size) % dim_size;
406 
407       } else {
408         target_mesh_coordinate = mesh_coordinate_for_device_id[index];
409       }
410 
411       target_device_id +=
412           target_mesh_coordinate *
413           GetMeshDimensionOffsetWithNeighbor(mesh, mesh_dim.name);
414     }
415     src_target_pair_flat.push_back(src_device_id);
416     src_target_pair_flat.push_back(target_device_id);
417   }
418 
419   const int num_pairs = src_target_pair_flat.size() / 2;
420   auto shaped_type = mlir::RankedTensorType::get(
421       {num_pairs, 2}, mlir::IntegerType::get(builder.getContext(), 32));
422 
423   auto src_target_attr =
424       mlir::DenseIntElementsAttr::get(shaped_type, src_target_pair_flat);
425   mlir::Value src_target_pair_tensor =
426       builder.create<mlir::TF::ConstOp>(location, src_target_attr);
427   return src_target_pair_tensor;
428 }
429 
430 }  // namespace
431 
432 StatusOr<mlir::Value> EmitHaloExchange(mlir::OpBuilder& builder, int halo_size,
433                                        const std::string& mesh_dim,
434                                        const Layout& layout,
435                                        mlir::Value mesh_coordinates,
436                                        mlir::tf_device::ClusterOp cluster,
437                                        mlir::Location location,
438                                        mlir::Value tensor) {
439   const Mesh& mesh = layout.mesh();
440 
441   // Check mesh dimension requirements for halo exchange.
442   if (!mesh.IsMeshDim(mesh_dim))
443     return errors::InvalidArgument(
444         "Requested halo exchange on unknown mesh dim");
445 
446   // TODO(hongjunchoi): Add support for halo exchange on GPU/CPU.
447   if (!mesh.is_tpu_mesh())
448     return errors::InvalidArgument("Halo exchange is only supported on TPU.");
449 
450   auto input_tensor_type = tensor.getType().dyn_cast<mlir::RankedTensorType>();
451   if (!input_tensor_type || !input_tensor_type.hasStaticShape())
452     return errors::InvalidArgument(
453         "Static shape of input tensor must be known for halo exchange.");
454 
455   llvm::ArrayRef<int64_t> input_tensor_shape = input_tensor_type.getShape();
456   const std::vector<std::string> sharding_specs = layout.sharding_spec_strs();
457   const int split_dim_index = std::distance(
458       sharding_specs.begin(), llvm::find(sharding_specs, mesh_dim));
459 
460   if (input_tensor_shape[split_dim_index] < halo_size)
461     return errors::InvalidArgument(
462         "For halo exchange, input shard tensor size of each processor must be "
463         "greater than halo size");
464 
465   TF_ASSIGN_OR_RETURN(const int mesh_dim_index, mesh.idx_for_dim(mesh_dim));
466 
467   TF_ASSIGN_OR_RETURN(mlir::Value scalar_mesh_coordinate,
468                       SelectScalarValueFromArray(builder, mesh_dim_index,
469                                                  location, mesh_coordinates));
470 
471   llvm::SmallVector<int64_t, 4> halo_exchange_tensor_shape;
472   for (const auto& size_and_index : llvm::enumerate(input_tensor_shape)) {
473     const int index = size_and_index.index();
474     const int size = size_and_index.value();
475     halo_exchange_tensor_shape.push_back(index == split_dim_index ? halo_size
476                                                                   : size);
477   }
478 
479   // Find the halo tensor value to pad on the `left` side. Note that halo
480   // exchange can happen on the top/bottom/left/right sides of a spatially
481   // partitioned tensor. However, we use `left`/`right` since the actual
482   // direction is implicit from the mesh dimension.
483   //
484   // For example, if the mesh dimension splits the input tensor along its
485   // height dimension, then `left` actually refers to padding on the top side.
486   mlir::Value is_on_left_edge = builder.create<mlir::TF::EqualOp>(
487       location, CreateIntScalarConst(0, builder, location, /*use_int64=*/false),
488       scalar_mesh_coordinate, builder.getBoolAttr(true));
489 
490   TF_ASSIGN_OR_RETURN(const int mesh_dim_size, mesh.dim_size(mesh_dim));
491   mlir::Value is_on_right_edge = builder.create<mlir::TF::EqualOp>(
492       location,
493       CreateIntScalarConst(mesh_dim_size - 1, builder, location,
494                            /*use_int64=*/false),
495       scalar_mesh_coordinate, builder.getBoolAttr(true));
496 
497   // Create zero ghost tensor to pad on left side.
498   mlir::RankedTensorType halo_tensor_type = mlir::RankedTensorType::get(
499       halo_exchange_tensor_shape, input_tensor_type.getElementType());
500   auto halo_type = mlir::RankedTensorType::get(
501       halo_tensor_type.getShape(), input_tensor_type.getElementType());
502 
503   mlir::Attribute const_attr;
504   if (halo_type.getElementType().isIntOrIndex()) {
505     const_attr =
506         mlir::DenseIntElementsAttr::get(halo_type, llvm::SmallVector<int>{0});
507   } else {
508     const_attr =
509         mlir::DenseFPElementsAttr::get(halo_type, llvm::SmallVector<float>{0});
510   }
511 
512   mlir::Value ghost_tensor_left =
513       builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();
514 
515   // Get the right side slice of the input tensor to pad on left side.
516   llvm::SmallVector<int64_t, 4> begin_left(layout.rank(), 0);
517   begin_left[split_dim_index] = input_tensor_shape[split_dim_index] - halo_size;
518   mlir::Value begin_tensor_left =
519       ops_util::GetR1Const(begin_left, builder, location);
520 
521   llvm::SmallVector<int64_t, 4> size(input_tensor_shape.begin(),
522                                      input_tensor_shape.end());
523   size[split_dim_index] = halo_size;
524 
525   mlir::Value size_tensor_left = ops_util::GetR1Const(size, builder, location);
526   mlir::Value sliced_tensor_left = builder.create<mlir::TF::SliceOp>(
527       location, halo_type, tensor, begin_tensor_left, size_tensor_left);
528 
529   mlir::Value halo_tensor_left = builder.create<mlir::TF::SelectV2Op>(
530       location, is_on_right_edge, ghost_tensor_left, sliced_tensor_left);
531 
532   // Invoke collective permute to receive the tensor from the neighboring processor.
533   // Halo slices from the left neighbor are received on each processor (they
534   // are shifted right).
535   TF_ASSIGN_OR_RETURN(
536       mlir::Value src_target_pair_left,
537       CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/false, location,
538                                builder));
539 
540   mlir::Value left_concat_value = builder.create<mlir::TF::CollectivePermuteOp>(
541       location, sliced_tensor_left.getType(), halo_tensor_left,
542       src_target_pair_left);
543 
544   mlir::Value ghost_tensor_right =
545       builder.create<mlir::TF::ConstOp>(location, const_attr).getResult();
546 
547   // Otherwise, the values to pad come from a different processor. We use
548   // collective permute to access the tensor slice from another device.
549   // Get the left side slice of the input tensor.
550   llvm::SmallVector<int64_t, 4> begin_right(layout.rank(), 0);
551   mlir::Value begin_tensor_right =
552       ops_util::GetR1Const(begin_right, builder, location);
553   mlir::Value size_tensor_right = ops_util::GetR1Const(size, builder, location);
554   mlir::Value sliced_tensor_right = builder.create<mlir::TF::SliceOp>(
555       location, halo_type, tensor, begin_tensor_right, size_tensor_right);
556 
557   // Find the halo tensor value to pad on the `right` side.
558   // If input block is on the right edge, we use zero ghost tensor instead.
559   mlir::Value halo_tensor_right = builder.create<mlir::TF::SelectV2Op>(
560       location, is_on_left_edge, ghost_tensor_right, sliced_tensor_right);
561 
562   // Invoke collective permute to receive the tensor from the neighboring processor.
563   // Halo slices from the right neighbor are received on each processor (they
564   // are shifted left).
565   TF_ASSIGN_OR_RETURN(
566       mlir::Value src_target_pair_right,
567       CreateConstSrcTargetPair(mesh, mesh_dim, /*shift_left=*/true, location,
568                                builder));
569   mlir::Value right_concat_value =
570       builder.create<mlir::TF::CollectivePermuteOp>(
571           location, sliced_tensor_right.getType(), halo_tensor_right,
572           src_target_pair_right);
573 
574   // The final halo-exchanged value is the concatenation of left_concat_value,
575   // tensor, and right_concat_value along the mesh dimension.
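  // For example, a hypothetical local shard of shape [8, 128] that is split on
  // dimension 0 with halo_size = 2 yields a final shape of [12, 128].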
576   llvm::SmallVector<int64_t, 4> final_shape(input_tensor_shape.begin(),
577                                             input_tensor_shape.end());
578   final_shape[split_dim_index] = final_shape[split_dim_index] + 2 * halo_size;
579 
580   auto final_type = mlir::RankedTensorType::get(
581       final_shape, input_tensor_type.getElementType());
582   mlir::Value concat_axis =
583       CreateIntScalarConst(split_dim_index, builder, location);
584   mlir::Value final_value = builder.create<mlir::TF::ConcatV2Op>(
585       location, final_type,
586       llvm::SmallVector<mlir::Value, 4>{left_concat_value, tensor,
587                                         right_concat_value},
588       concat_axis);
589 
590   return final_value;
591 }
592 
593 }  // namespace dtensor
594 }  // namespace tensorflow
595