xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/expansions/slice_spmd_expander.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21 
#include "absl/container/flat_hash_set.h"
#include "absl/numeric/bits.h"

22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/dtensor/cc/dstatus.h"
29 #include "tensorflow/dtensor/mlir/collectives.h"
30 #include "tensorflow/dtensor/mlir/layout_parsing.h"
31 #include "tensorflow/dtensor/mlir/shape_utils.h"
32 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
33 #include "tensorflow/dtensor/mlir/value_utils.h"
34 #include "tensorflow/dtensor/proto/layout.pb.h"
35 
36 namespace tensorflow {
37 namespace dtensor {
38 namespace {
39 
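// Extracts the begin and size operands of a tf.Slice op. The begin operand may
// be dynamic, in which case `dynamic_begins` is set to true and `begins` is not
// populated; the size operand must be constant or an error is returned.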
40 Status GetSliceOpArguments(mlir::TF::SliceOp slice_op,
41                            llvm::SmallVector<int64_t, 4>& begins,
42                            bool& dynamic_begins,
43                            llvm::SmallVector<int64_t, 4>& sizes) {
44   Status begins_result = ExtractConstVectorFromValue(slice_op.begin(), &begins);
45   dynamic_begins = !begins_result.ok();
46 
47   TF_RETURN_WITH_CONTEXT(ExtractConstVectorFromValue(slice_op.size(), &sizes),
48                          "expected constant argument for SliceOp::size()");
49 
50   return OkStatus();
51 }
52 
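// Proposes a layout under which the slice can be lowered to a purely local
// slice: sharding is kept on dimensions that are either already unsharded or
// receive a full slice, while any partially sliced dimension is proposed as
// unsharded.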
53 StatusOr<Layout> VerifySliceLayout(
54     mlir::Operation* slice_op, mlir::Value value, const Layout& layout,
55     llvm::ArrayRef<int64_t>* global_shape = nullptr) {
56   if (layout.IsFullyReplicated()) return layout;
57 
58   TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> shape,
59                       GetShapeOfValue(value, /*fail_on_dynamic=*/true));
60   const int64_t rank = shape.size();
61   if (global_shape != nullptr) {
62     // In ExpandOp, the tensor shape is the local shape, so the call site
63     // needs to provide the global shape explicitly.
64     shape = *global_shape;
65   }
66 
67   llvm::SmallVector<int64_t, 4> begins, sizes;
68   bool dynamic_begins = false;
69   begins.reserve(rank);
70   sizes.reserve(rank);
71 
72   TF_RETURN_IF_ERROR(GetSliceOpArguments(
73       llvm::cast<mlir::TF::SliceOp>(slice_op), begins, dynamic_begins, sizes));
74 
75   auto num_shards = layout.num_shards();
76 
77   LayoutProto proposed_proto;
78   *proposed_proto.mutable_mesh_config() = layout.mesh().ToProto();
79   for (int64_t i = 0; i < rank; ++i) {
80     // Slice performed on replicated dimension translates to local expansion.
81     if (num_shards[i] == 1) {
82       proposed_proto.add_sharding_specs()->set_sharding_spec(
83           Layout::kUnshardedDim);
84       continue;
85     }
86 
87     const bool begins_starts_at_zero =
88         (sizes[i] == shape[i]) || (!dynamic_begins && begins[i] == 0);
89     const bool ends_at_full_size =
90         (sizes[i] == shape[i]) || (!dynamic_begins && sizes[i] == -1);
91 
92     if (begins_starts_at_zero && ends_at_full_size) {
93       // We support slicing with dynamic begins when the sharded dimensions are
94       // getting a full slice. Since we don't know the begins in this case, we
95       // need to rely on the sizes being static and equal to the global shape.
96       // In particular, sizes[i] == shape[i] implies begins[i] == 0.
97       // A full slice over any dimension can be performed locally.
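      // For example (hand-worked): with global shape [8] and 2 shards,
      // begins == [0] with sizes == [8] (or [-1]) keeps the sharding, while
      // begins == [2] with sizes == [4] forces the dimension to be unsharded.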
98       proposed_proto.add_sharding_specs()->set_sharding_spec(
99           layout.sharding_spec(i));
100     } else {
101       // Slicing on sharded dim is not trivial. Propose an unsharded dim for
102       // that.
103       proposed_proto.add_sharding_specs()->set_sharding_spec(
104           Layout::kUnshardedDim);
105     }
106   }
107   return Layout::FromProto(proposed_proto);
108 }
109 
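// Expands a bit mask into a vector of 64 zero/one entries, least significant
// bit first. For example, mask_value == 0b101 yields {1, 0, 1, 0, ..., 0}.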
110 llvm::SmallVector<int64_t, 4> CalculateBitVector(const uint64_t mask_value) {
111   llvm::SmallVector<int64_t, 4> bit_vector;
112   bit_vector.resize(sizeof(uint64_t) * 8, 0);
113   for (int i = 0; i < sizeof(uint64_t) * 8; ++i) {
114     bit_vector[i] = (mask_value >> i & 1);
115   }
116   return bit_vector;
117 }
118 
119 // The begin/end/stride and the masks are all sized to match the number of
120 // entries in the slice specification. E.g. [:, ..., 3] will have a begin/end/
121 // stride of size 3 and the max set bit in the mask will be the 3rd bit.
122 // This function converts this specifications into ones relative to the input
123 // tensor.
124 // We also output a bool vector of the input indices which are not shrunk away.
125 // The shrunk axes must always be replicated, since shrinking an axis means we
126 // take a single element along it, and it must be present on all cores.
127 // spec_to_input maps the 'spec' dimensions to the input dimensions. This is
128 // needed so we can create a new 'end' input for the SPMD expanded op.
129 //
130 // NOTE: If the begin or ends are dynamic, they will be size 0.
131 // If strides is dynamic it will be the correct rank but contain 0s (an invalid
132 // stride).
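// For example (a hand-worked sketch): slicing a rank-2 input as x[1] has a spec
// rank of 1, shrink_axis_mask 0b1 and no explicit ellipsis. An ellipsis is then
// appended, so the outputs describe both input axes: axis 0 keeps its explicit
// begin/end/stride and is marked as shrunk, while axis 1 falls in the implicit
// ellipsis region and gets stride 1 with its begin and end mask bits set.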
133 template <typename T>
134 Status GetInputOrientedData(T strided_slice,
135                             llvm::SmallVectorImpl<int64_t>* begin,
136                             uint64_t* begin_mask,
137                             llvm::SmallVectorImpl<int64_t>* end,
138                             uint64_t* end_mask,
139                             llvm::SmallVectorImpl<int64_t>* strides,
140                             llvm::SmallVectorImpl<bool>* not_shrunk,
141                             llvm::SmallVectorImpl<int64>* spec_to_input) {
142   begin->resize(0);
143   end->resize(0);
144   strides->resize(0);
145 
146   llvm::SmallVector<int64_t, 4> spec_begin;
147   llvm::SmallVector<int64_t, 4> spec_end;
148   llvm::SmallVector<int64_t, 4> spec_strides;
149 
150   TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> strides_shape,
151                       GetShapeOfValue(strided_slice.strides(),
152                                       /*fail_on_dynamic=*/true));
153   if (strides_shape.size() != 1)
154     return errors::InvalidArgument(
155         "strides input to strided operation is not rank 1");
156 
157   int64_t spec_rank = strides_shape[0];
158   spec_to_input->resize(spec_rank, -1);
159 
160   if (!ExtractConstVectorFromValue(strided_slice.strides(), &spec_strides).ok())
161     spec_strides.resize(spec_rank, 0);
162 
163   if (ExtractConstVectorFromValue(strided_slice.begin(), &spec_begin).ok())
164     if (spec_begin.size() != spec_rank)
165       return errors::InvalidArgument(
166           "rank of begin input to strided operation does not equal rank of "
167           "strides input");
168 
169   if (ExtractConstVectorFromValue(strided_slice.end(), &spec_end).ok())
170     if (spec_end.size() != spec_rank)
171       return errors::InvalidArgument(
172           "rank of end input to strided operation does not equal rank of "
173           "strides input");
174 
175   const uint64_t new_axis_mask = strided_slice.new_axis_mask();
176   const uint64_t shrink_axis_mask = strided_slice.shrink_axis_mask();
177   const uint64_t spec_begin_mask = strided_slice.begin_mask();
178   const uint64_t spec_end_mask = strided_slice.end_mask();
179   uint64_t ellipsis_mask = strided_slice.ellipsis_mask();
180 
181   int64_t input_rank;
182   if (mlir::isa<mlir::TF::StridedSliceOp>(strided_slice) ||
183       mlir::isa<mlir::TF::TensorStridedSliceUpdateOp>(strided_slice)) {
184     // For StridedSlice the first operand is the input.
185     input_rank = ValueRank(strided_slice->getOperand(0));
186   } else if (mlir::isa<mlir::TF::StridedSliceGradOp>(strided_slice)) {
187     // For StridedSliceGrad the first operand is the shape of the input.
188     TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> input_shape,
189                         GetShapeOfValue(strided_slice->getOperand(0)));
190     if (input_shape.size() != 1)
191       return errors::InvalidArgument("input shape must be rank 1");
192     input_rank = input_shape[0];
193   }
194 
195   if (absl::popcount(ellipsis_mask) > 1)
196     return errors::InvalidArgument(
197         "strided slice only supports at most one ellipsis");
198 
199   // Count the number of axes after the ellipsis
200   bool found_ellipsis = false;
201   int64_t num_add_axis_after_ellipsis = 0;
202   for (int64_t i = 0; i < spec_rank; ++i) {
203     if (found_ellipsis && ((1 << i) & new_axis_mask))
204       num_add_axis_after_ellipsis++;
205     if ((1 << i) & ellipsis_mask) found_ellipsis = true;
206   }
207   // Guarantee one ellipsis. If there isn't one, add it at the end of the spec.
208   // If we do this, add one to the total rank so that we process the ellipsis as
209   // part of the loop below.
210   if (!found_ellipsis) ellipsis_mask |= (1 << (spec_rank++));
211 
212   // At this point total rank cannot be more than input_rank + number of
213   // new axes plus the number of ellipses. Check that condition so that we know
214   // the loop below won't have input_index >= input_rank.
215   if (spec_rank > input_rank + absl::popcount(new_axis_mask) + 1)
216     return errors::InvalidArgument(
217         "incompatible input rank, number of new axis and specification rank: ",
218         input_rank, ", ", absl::popcount(new_axis_mask), ", ", spec_rank);
219 
220   int64_t input_index = 0;
221   for (int64_t spec_index = 0; spec_index < spec_rank; ++spec_index) {
222     if ((1 << spec_index) & ellipsis_mask) {
223       const int64_t next_input_index =
224           std::min(input_rank - (spec_rank - spec_index) + 1 +
225                        num_add_axis_after_ellipsis,
226                    input_rank);
227       for (; input_index < next_input_index; input_index++) {
228         // For input axes within the ellipsis region, we include the entire axis
229         // by setting the begin and end mask.
230         not_shrunk->emplace_back(true);
231         if (!spec_begin.empty()) begin->emplace_back(0);
232         if (!spec_end.empty()) end->emplace_back(0);
233         strides->emplace_back(1);
234         (*begin_mask) |= 1 << input_index;
235         (*end_mask) |= 1 << input_index;
236       }
237     } else if (((1 << spec_index) & new_axis_mask) == 0) {
238       not_shrunk->emplace_back(((1 << spec_index) & shrink_axis_mask) == 0);
239       if (!spec_begin.empty()) begin->emplace_back(spec_begin[spec_index]);
240       if (!spec_end.empty()) end->emplace_back(spec_end[spec_index]);
241       strides->emplace_back(spec_strides[spec_index]);
242       (*spec_to_input)[spec_index] = input_index;
243       (*begin_mask) |= ((spec_begin_mask >> spec_index) & 1) << input_index;
244       (*end_mask) |= ((spec_end_mask >> spec_index) & 1) << input_index;
245       input_index++;
246     }
247   }
248 
249   // This should not happen.
250   if (input_index != input_rank)
251     return errors::Internal("strided slice input not totally processed");
252 
253   return OkStatus();
254 }
255 
256 // Return an intermediate layout for StridedSlice(Grad), where we can lower the
257 // global StridedSlice(Grad) to a local one.
258 // All the inputs (begin/end/stride/masks) are sized to match the 'total rank',
259 // which is the rank of the input plus the number of new dimensions added (i.e.
260 // the number of bits set in the new_axis_mask).
261 // The values of these inputs on the 'newly added' dimensions are ignored.
262 // global_input_shape is the global shape for the main input of StridedSlice or
263 // equivalently the global shape of the output of StridedSliceGrad.
264 // If new_end is not a nullptr, it will be set to the new ending vector if
265 // the end was constant, otherwise it will be cleared.
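// For example (a hand-worked sketch): slicing x[0:6:2] on an axis of global
// size 8 split over 2 shards leaves 8 - (6 - 0) = 2 >= stride elements unused,
// so that axis is proposed as unsharded; the full slice x[0:8:2] keeps the
// sharding, since the local size 4 is divisible by the stride and no elements
// are left over.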
266 template <typename T>
267 StatusOr<Layout> GetStridedSliceIntermediateLayout(
268     T strided_slice, const Layout& layout,
269     const llvm::ArrayRef<int64_t> global_input_shape,
270     llvm::SmallVectorImpl<int64_t>* new_end = nullptr) {
271   const int64_t rank = global_input_shape.size();
272 
273   // Records if the corresponding axis of the input can be sharded.
274   llvm::SmallVector<bool, 4> can_shard;
275   // Lists the start/end of the slice; the values are clamped to the valid
276   // range below.
277   llvm::SmallVector<int64_t, 4> begin;
278   llvm::SmallVector<int64_t, 4> end;
279   // Lists the stride for each tensor dimension. Positive when it is constant
280   // and 0 when it is dynamic.
281   llvm::SmallVector<int64_t, 4> strides;
282   llvm::SmallVector<int64_t, 4> total_to_input;
283   // The current number of shards along each axis.
284   const std::vector<int32> shards = layout.num_shards();
285 
286   uint64_t begin_mask = 0;
287   uint64_t end_mask = 0;
288 
289   TF_RETURN_IF_ERROR(GetInputOrientedData(strided_slice, &begin, &begin_mask,
290                                           &end, &end_mask, &strides, &can_shard,
291                                           &total_to_input));
292 
293   bool const_begin = !begin.empty();
294   bool const_end = !end.empty();
295 
296   if (!const_begin) begin.resize(rank, 0);
297 
298   if (!const_end) end.resize(rank, 0);
299 
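  // Normalize begin/end per input axis: apply the begin/end masks, wrap
  // negative indices and clamp them to the valid range, then decide for each
  // axis whether the existing sharding can be kept.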
300   for (int i = 0; i < rank; ++i) {
301     if ((1 << i) & begin_mask)
302       begin[i] = 0;
303     else if (begin[i] < 0)
304       begin[i] += global_input_shape[i];
305 
306     if (begin[i] < 0l) {
307       begin[i] = 0l;
308     } else if (begin[i] > global_input_shape[i] - 1) {
309       begin[i] = global_input_shape[i] - 1;
310     }
311 
312     if ((1 << i) & end_mask)
313       end[i] = global_input_shape[i];
314     else if (end[i] < 0)
315       end[i] += global_input_shape[i];
316 
317     if (end[i] < 1l) {
318       end[i] = 1l;
319     } else if (end[i] > global_input_shape[i]) {
320       end[i] = global_input_shape[i];
321     }
322 
323     // A negative or dynamic stride requires an unsharded axis.
324     if (strides[i] < 1) can_shard[i] = false;
325     // The local size must be divisible by the stride, otherwise the begin
326     // for each local slice would be different.
327     if ((global_input_shape[i] / shards[i]) % strides[i] != 0)
328       can_shard[i] = false;
329     // If start or end are dynamic we can't shard.
330     if (!(((1 << i) & begin_mask) || const_begin) ||
331         !(((1 << i) & end_mask) || const_end))
332       can_shard[i] = false;
333     // Finally, if the amount of space left on the 'left' and 'right' of the
334     // tensor is more than (or equal to) a stride, we can't shard, as there
335     // would be an unequal number of outputs per shard.
336     // NOTE: the case of end[i] == begin[i] may be a simple optimization since
337     // the result is an empty tensor.
338     if (global_input_shape[i] - (end[i] - begin[i]) >= strides[i])
339       can_shard[i] = false;
340     // If there is currently no sharding, it doesn't make sense to shard.
341     if (shards[i] == 1) can_shard[i] = false;
342   }
343 
344   // Compute the new 'end' for the slice. Note that this end needs to be in
345   // terms of the 'total' index not the input index (i.e. it needs 'bogus'
346   // entries for the new axes).
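  // For example (hand-worked): if input axis `inp` stays sharded with global
  // size 8 over 2 shards, its new end becomes the local size 4; otherwise the
  // clamped global end is kept.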
347   if (new_end != nullptr) {
348     if (!const_end) {
349       // A dynamic end is left unchanged. We indicate this by ensuring the
350       // passed-in vector is empty.
351       new_end->clear();
352     } else {
353       new_end->resize(total_to_input.size());
354       for (int i = 0; i < total_to_input.size(); ++i) {
355         const int64_t inp = total_to_input[i];
356         if (inp != -1) {
357           // If we can keep input axis inp sharded, we need to update the end.
358           // Given the conditions we enforced above, we can set end to the
359           // local size of the input.
360           if (can_shard[inp])
361             (*new_end)[i] = global_input_shape[inp] / shards[inp];
362           else
363             (*new_end)[i] = end[inp];
364         }
365       }
366     }
367   }
368 
369   // Compute the new layout; it is basically the old layout but replicated on
370   // some axes.
371   absl::flat_hash_set<int> reduced_dims;
372   for (int i = 0; i < rank; ++i)
373     if (!can_shard[i]) reduced_dims.emplace(i);
374   return layout.GetLayoutWithReducedDims(reduced_dims, /*keep_dims=*/true);
375 }
376 
377 enum Direction {
378   FORWARD,
379   BACKWARD,
380 };
381 
382 // Applies the shrink and new-axis masks to a layout. This function works in both
383 // the forward and backward directions, as specified by the direction argument.
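// For example (a hand-worked sketch): going FORWARD with input rank 2, output
// rank 3, new_axis_mask 0b010 and a proposed layout of ["x", "y"], the result
// interleaves an unsharded dimension for the new axis: ["x", unsharded, "y"].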
384 template <typename SliceOpT>
385 StatusOr<Layout> ApplyNewAndShrinkMasksToLayout(SliceOpT slice_op,
386                                                 const int input_rank,
387                                                 const int output_rank,
388                                                 const Layout& proposed_layout,
389                                                 const Direction direction) {
390   // Calculate bit mask for shrunk dimensions/newly added dimensions.
391   const llvm::SmallVector<int64_t, 4> new_axis_mask =
392       CalculateBitVector(slice_op.new_axis_mask());
393   const llvm::SmallVector<int64_t, 4> shrink_axis_mask =
394       CalculateBitVector(slice_op.shrink_axis_mask());
395 
396   std::vector<std::string> sharding_spec;
397   int input_dim_index = 0;
398   int output_dim_index = 0;
399   int current_dimension_index = 0;
400   while (current_dimension_index < proposed_layout.rank()) {
401     if (input_dim_index < input_rank &&
402         shrink_axis_mask[input_dim_index] == 1) {
403       input_dim_index++;
404       if (direction == BACKWARD)
405         sharding_spec.emplace_back(Layout::kUnshardedDim);
406       else
407         current_dimension_index++;
408     } else if (output_dim_index < output_rank &&
409                new_axis_mask[output_dim_index] == 1) {
410       if (direction == FORWARD)
411         sharding_spec.emplace_back(Layout::kUnshardedDim);
412       else
413         current_dimension_index++;
414       output_dim_index++;
415     } else {
416       sharding_spec.emplace_back(
417           proposed_layout.sharding_spec(current_dimension_index));
418       input_dim_index++;
419       output_dim_index++;
420       current_dimension_index++;
421     }
422   }
423 
424   const auto& mask = (direction == FORWARD) ? new_axis_mask : shrink_axis_mask;
425   // New dimensions may be added after all dimensions have been sliced.
426   while (current_dimension_index < mask.size() &&
427          mask[current_dimension_index] == 1) {
428     sharding_spec.emplace_back(Layout::kUnshardedDim);
429     current_dimension_index++;
430   }
431 
432   return Layout::GetLayout(sharding_spec, proposed_layout.mesh());
433 }
434 
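// Materializes `values` as a constant whose integer element type (int32 or
// int64) matches the element type of `type`.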
435 mlir::Value IntConstWithMatchingType(mlir::OpBuilder& builder,
436                                      mlir::Location loc,
437                                      llvm::ArrayRef<int64_t> values,
438                                      mlir::Type type) {
439   if (type.cast<mlir::RankedTensorType>().getElementType().isInteger(64)) {
440     return Int64Const(builder, loc, values);
441   } else {
442     llvm::SmallVector<int32, 4> values32(values.begin(), values.end());
443     return IntConst(builder, loc, values32);
444   }
445 }
446 
447 }  // namespace
448 
449 StatusOr<mlir::Operation*> SliceSPMDExpander::ExpandOp(mlir::Operation* op) {
450   auto slice_op = mlir::cast<mlir::TF::SliceOp>(op);
451   TF_ASSIGN_OR_RETURN(auto input_layout,
452                       ExtractLayoutFromOperand(slice_op.input()));
453   TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
454 
455   if (!output_layout || !input_layout)
456     return errors::Unimplemented(
457         "layout of Slice op must be known before SPMD expansion.");
458 
459   // The dyn_cast will never be nullptr as it is checked in
460   // GetLayoutFromOperands.
461   auto input_type =
462       slice_op.input().getType().dyn_cast<mlir::RankedTensorType>();
463   if (!input_type)
464     return errors::InvalidArgument(
465         "rank of input tensor must be statically known for slice op.");
466 
467   TF_ASSIGN_OR_RETURN(auto global_shape,
468                       ExtractGlobalInputShape(op->getOpOperand(0)));
469   const int64_t input_rank = input_type.getRank();
470 
471   llvm::SmallVector<int64_t, 4> begins, sizes;
472   bool dynamic_begins = false;
473   begins.reserve(input_rank);
474   sizes.reserve(input_rank);
475 
476   TF_RETURN_IF_ERROR(
477       GetSliceOpArguments(slice_op, begins, dynamic_begins, sizes));
478 
479   TF_ASSIGN_OR_RETURN(auto proposed_layout,
480                       VerifySliceLayout(slice_op, slice_op.input(),
481                                         *input_layout, &global_shape));
482 
483   llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
484 
485   TF_ASSIGN_OR_RETURN(auto relayout_input,
486                       EmitRelayout(op->getOperand(0), *input_layout,
487                                    proposed_layout, &newly_created_ops));
488   {
489     // Adjust the sizes when a full slice is taken on a sharded dimension.
490     // Note that the proposed layout is unsharded in the cases where:
491     // 1) We can't determine the begins and sizes != global shape
492     // 2) begins != 0
493     // 3) sizes != global shape or -1
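    // For example (hand-worked): with global shape [8], 2 shards on that
    // dimension and a full slice (begins[0] == 0, sizes[0] == 8), sizes[0] is
    // rewritten to the local size 8 / 2 = 4.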
494     const std::vector<int> num_shards = proposed_layout.num_shards();
495     for (int64_t i = 0; i < input_rank; ++i) {
496       if (num_shards[i] == 1) continue;
497 
498       if (sizes[i] == -1 && !dynamic_begins && begins[i] == 0) continue;
499 
500       if (sizes[i] == global_shape[i]) {
501         // Set the correct output size. If the input is dynamic and this is -1,
502         // then shape inference can't tell the output shape.
503         sizes[i] = global_shape[i] / num_shards[i];
504         continue;
505       }
506 
507       return errors::InvalidArgument(
508           "Non-full-slicing on the sharded dimension is not allowed. "
509           "internal bug.");
510     }
511   }
512 
513   mlir::OpBuilder builder(op);
514   mlir::Value new_size;
515   auto loc = op->getLoc();
516   // Both begin and size need to be the same type, so we must match the new
517   // size input with the type of begin.
518   if (!slice_op.begin().getType().isa<mlir::ShapedType>())
519     return errors::Internal("type of begin is not a ShapedType");
520   mlir::ShapedType type = slice_op.begin().getType().cast<mlir::ShapedType>();
521   if (type.getElementType().isInteger(32))
522     new_size = IntConst(
523         builder, loc, llvm::SmallVector<int32, 4>(sizes.begin(), sizes.end()));
524   else
525     new_size = Int64Const(builder, loc, sizes);
526 
527   auto new_op =
528       builder
529           .create<mlir::TF::SliceOp>(loc, slice_op.output().getType(),
530                                      relayout_input, slice_op.begin(), new_size)
531           .getOperation();
532   new_op = InferSPMDExpandedLocalShape(new_op);
533 
534   TF_ASSIGN_OR_RETURN(auto relayout_output,
535                       EmitRelayout(new_op->getResult(0), proposed_layout,
536                                    *output_layout, &newly_created_ops));
537 
538   op->getOpResult(0).replaceAllUsesExcept(relayout_output, newly_created_ops);
539   op->erase();
540   return relayout_output.getDefiningOp();
541 }
542 
543 StatusOr<llvm::DenseMap<int, Layout>> SliceSPMDExpander::ComputeLayoutForward(
544     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
545   // If the input layout is missing, don't return an output layout.
546   if (input_layouts.find(0) == input_layouts.end())
547     return llvm::DenseMap<int, Layout>();
548 
549   auto slice_op = mlir::cast<mlir::TF::SliceOp>(op);
550 
551   const Layout& input_layout = input_layouts.lookup(0);
552   TF_ASSIGN_OR_RETURN(
553       auto proposed_layout,
554       VerifySliceLayout(slice_op, slice_op.input(), input_layout));
555   return llvm::DenseMap<int, Layout>({{0, proposed_layout}});
556 }
557 
558 StatusOr<llvm::DenseMap<int, Layout>> SliceSPMDExpander::ComputeLayoutBackward(
559     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
560   auto slice_op = mlir::cast<mlir::TF::SliceOp>(op);
561   TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
562 
563   llvm::DenseMap<int, Layout> input_layouts(slice_op.getNumOperands());
564   // Set replicated layout for begin and size operands.
565   input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
566   input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
567 
568   // input
569   if (output_layouts.find(0) != output_layouts.end()) {
570     const Layout& output_layout = output_layouts.lookup(0);
571     TF_ASSIGN_OR_RETURN(
572         auto proposed_layout,
573         VerifySliceLayout(slice_op, slice_op.output(), output_layout));
574     input_layouts[0] = proposed_layout;
575   }
576 
577   return input_layouts;
578 }
579 
580 StatusOr<mlir::Operation*> StridedSliceSPMDExpander::ExpandOp(
581     mlir::Operation* op) {
582   auto strided_slice_op = mlir::cast<mlir::TF::StridedSliceOp>(op);
583   TF_ASSIGN_OR_RETURN(Layout input_layout, ExtractRequiredLayoutFromOperand(
584                                                strided_slice_op.input()));
585   TF_ASSIGN_OR_RETURN(Layout output_layout,
586                       ExtractRequiredSingleLayoutFromOp(op));
587   TF_ASSIGN_OR_RETURN(
588       const llvm::ArrayRef<int64_t> global_input_shape,
589       GetGlobalShapeOfValueFromDTensorLayout(strided_slice_op.input()));
590 
591   llvm::SmallVector<int64_t, 4> end;
592   TF_ASSIGN_OR_RETURN(
593       Layout intermediate_input_layout,
594       GetStridedSliceIntermediateLayout(strided_slice_op, input_layout,
595                                         global_input_shape, &end));
596 
597   TF_ASSIGN_OR_RETURN(mlir::Value new_input,
598                       EmitRelayout(strided_slice_op.input(), input_layout,
599                                    intermediate_input_layout));
600 
601   strided_slice_op.inputMutable().assign(new_input);
602 
603   mlir::OpBuilder builder(op);
604 
605   if (!end.empty()) {
606     mlir::Value new_end =
607         IntConstWithMatchingType(builder, strided_slice_op.getLoc(), end,
608                                  strided_slice_op.begin().getType());
609     strided_slice_op.endMutable().assign(new_end);
610   }
611 
612   op = InferSPMDExpandedLocalShape(op);
613 
614   // Compute the layout of the output after the local StridedSlice takes place.
615   const int input_rank = global_input_shape.size();
616   const int output_rank = ValueRank(strided_slice_op.output());
617 
618   // Calculate bit mask for shrunk dimensions/newly added dimensions.
619   const llvm::SmallVector<int64_t, 4> new_axis_mask =
620       CalculateBitVector(strided_slice_op.new_axis_mask());
621   const llvm::SmallVector<int64_t, 4> shrink_axis_mask =
622       CalculateBitVector(strided_slice_op.shrink_axis_mask());
623 
624   TF_ASSIGN_OR_RETURN(
625       Layout intermediate_output_layout,
626       ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, output_rank,
627                                      intermediate_input_layout, FORWARD));
628 
629   // Do a final relayout to the correct output layout in case there are any
630   // differences between intermediate_output_layout and output_layout.
631   llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
632 
633   TF_ASSIGN_OR_RETURN(
634       mlir::Value output,
635       EmitRelayout(strided_slice_op.output(), intermediate_output_layout,
636                    output_layout, &newly_created_ops));
637 
638   strided_slice_op.output().replaceAllUsesExcept(output, newly_created_ops);
639 
640   return output.getDefiningOp();
641 }
642 
643 StatusOr<llvm::DenseMap<int, Layout>>
644 StridedSliceSPMDExpander::ComputeLayoutForward(
645     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
646   // If the input layout is missing, don't return an output layout.
647   if (input_layouts.find(0) == input_layouts.end())
648     return llvm::DenseMap<int, Layout>();
649 
650   mlir::TF::StridedSliceOp strided_slice_op =
651       mlir::cast<mlir::TF::StridedSliceOp>(op);
652   TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
653                       GetShapeOfValue(strided_slice_op.input(),
654                                       /*fail_on_dynamic=*/true));
655   const int input_rank = global_input_shape.size();
656   const int output_rank = ValueRank(strided_slice_op.output());
657 
658   const Layout& input_layout = input_layouts.lookup(0);
659   TF_ASSIGN_OR_RETURN(Layout proposed_layout,
660                       GetStridedSliceIntermediateLayout(
661                           strided_slice_op, input_layout, global_input_shape));
662   // If dimension was added or removed, create a new proposed output layout
663   // with dimensions added/skipped.
664   TF_ASSIGN_OR_RETURN(
665       proposed_layout,
666       ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, output_rank,
667                                      proposed_layout, FORWARD));
668   return llvm::DenseMap<int, Layout>({{0, proposed_layout}});
669 }
670 
671 StatusOr<llvm::DenseMap<int, Layout>>
672 StridedSliceSPMDExpander::ComputeLayoutBackward(
673     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
674   mlir::TF::StridedSliceOp strided_slice_op =
675       mlir::cast<mlir::TF::StridedSliceOp>(op);
676   TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
677 
678   TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
679                       GetShapeOfValue(strided_slice_op.input(),
680                                       /*fail_on_dynamic=*/true));
681   const int input_rank = global_input_shape.size();
682   const int output_rank = ValueRank(strided_slice_op.output());
683 
684   llvm::DenseMap<int, Layout> input_layouts(strided_slice_op.getNumOperands());
685   // Set replicated layout for begin, end, and strides operands.
686   input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
687   input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
688   input_layouts[3] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
689 
690   // input
691   if (output_layouts.find(0) != output_layouts.end()) {
692     // This layout must exist (as there is only one output).
693     const Layout& output_layout = output_layouts.lookup(0);
694     // If dimension was added or removed, take the current output layout, and
695     // add/skip dimensions in it as needed to get an input layout.
696     TF_ASSIGN_OR_RETURN(
697         Layout proposed_layout,
698         ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank,
699                                        output_rank, output_layout, BACKWARD));
700     TF_ASSIGN_OR_RETURN(proposed_layout, GetStridedSliceIntermediateLayout(
701                                              strided_slice_op, proposed_layout,
702                                              global_input_shape));
703     input_layouts[0] = proposed_layout;
704   }
705 
706   return input_layouts;
707 }
708 
709 StatusOr<mlir::Operation*> TensorStridedSliceUpdateSPMDExpander::ExpandOp(
710     mlir::Operation* op) {
711   mlir::TF::TensorStridedSliceUpdateOp strided_slice_op =
712       llvm::cast<mlir::TF::TensorStridedSliceUpdateOp>(op);
713   TF_ASSIGN_OR_RETURN(
714       const Layout input_layout,
715       ExtractRequiredLayoutFromOperand(strided_slice_op.input()));
716   TF_ASSIGN_OR_RETURN(
717       const Layout value_layout,
718       ExtractRequiredLayoutFromOperand(strided_slice_op.value()));
719   TF_ASSIGN_OR_RETURN(const Layout output_layout,
720                       ExtractRequiredSingleLayoutFromOp(op));
721 
722   TF_ASSIGN_OR_RETURN(
723       const llvm::ArrayRef<int64_t> global_input_shape,
724       GetGlobalShapeOfValueFromDTensorLayout(strided_slice_op.input()));
725 
726   const int input_rank = global_input_shape.size();
727   const int value_rank = ValueRank(strided_slice_op.value());
728 
729   llvm::SmallVector<int64_t, 4> end;
730   TF_ASSIGN_OR_RETURN(
731       Layout intermediate_input_layout,
732       GetStridedSliceIntermediateLayout(strided_slice_op, input_layout,
733                                         global_input_shape, &end));
734 
735   TF_ASSIGN_OR_RETURN(
736       Layout intermediate_value_layout,
737       ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, value_rank,
738                                      intermediate_input_layout, FORWARD));
739 
740   TF_ASSIGN_OR_RETURN(mlir::Value new_input,
741                       EmitRelayout(strided_slice_op.input(), input_layout,
742                                    intermediate_input_layout));
743 
744   TF_ASSIGN_OR_RETURN(mlir::Value new_value,
745                       EmitRelayout(strided_slice_op.value(), value_layout,
746                                    intermediate_value_layout));
747 
748   strided_slice_op.inputMutable().assign(new_input);
749   strided_slice_op.valueMutable().assign(new_value);
750 
751   mlir::OpBuilder builder(op);
752 
753   if (!end.empty()) {
754     mlir::Value new_end =
755         IntConstWithMatchingType(builder, strided_slice_op.getLoc(), end,
756                                  strided_slice_op.begin().getType());
757     strided_slice_op.endMutable().assign(new_end);
758   }
759 
760   op = InferSPMDExpandedLocalShape(op);
761 
762   // Do a final relayout to the correct output layout in case there are any
763   // differences between intermediate_input_layout and output_layout.
764   llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
765 
766   TF_ASSIGN_OR_RETURN(
767       mlir::Value output,
768       EmitRelayout(strided_slice_op.output(), intermediate_input_layout,
769                    output_layout, &newly_created_ops));
770 
771   strided_slice_op.output().replaceAllUsesExcept(output, newly_created_ops);
772 
773   return output.getDefiningOp();
774 }
775 
776 StatusOr<llvm::DenseMap<int, Layout>>
777 TensorStridedSliceUpdateSPMDExpander::ComputeLayoutForward(
778     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
779   // If both the input layout and the value layout are missing, don't return
780   // an output layout.
781   if (input_layouts.find(0) == input_layouts.end() &&
782       input_layouts.find(4) == input_layouts.end())
783     return llvm::DenseMap<int, Layout>();
784 
785   mlir::TF::TensorStridedSliceUpdateOp strided_slice_op =
786       mlir::cast<mlir::TF::TensorStridedSliceUpdateOp>(op);
787   TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
788                       GetShapeOfValue(strided_slice_op.input(),
789                                       /*fail_on_dynamic=*/true));
790   const int input_rank = global_input_shape.size();
791   const int value_rank = ValueRank(strided_slice_op.value());
792 
793   // We have a choice in determining the output layout: we default to using
794   // input_layout if it is available, otherwise we expand value_layout and use
795   // that.
796   Layout input_layout;
797   if (input_layouts.find(0) != input_layouts.end()) {
798     input_layout = input_layouts.lookup(0);
799   } else {
800     // When we don't have the input layout, use value layout to 'create' the
801     // input layout. We do this by applying the new and shrink masks backwards.
802     // This is because in the case of a normal strided slice the layout of
803     // value would be the output layout.
804     const Layout& value_layout = input_layouts.lookup(4);
805     TF_ASSIGN_OR_RETURN(input_layout, ApplyNewAndShrinkMasksToLayout(
806                                           strided_slice_op, input_rank,
807                                           value_rank, value_layout, BACKWARD));
808   }
809   TF_ASSIGN_OR_RETURN(Layout proposed_output_layout,
810                       GetStridedSliceIntermediateLayout(
811                           strided_slice_op, input_layout, global_input_shape));
812 
813   return llvm::DenseMap<int, Layout>({{0, proposed_output_layout}});
814 }
815 
816 StatusOr<llvm::DenseMap<int, Layout>>
817 TensorStridedSliceUpdateSPMDExpander::ComputeLayoutBackward(
818     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
819   mlir::TF::TensorStridedSliceUpdateOp strided_slice_op =
820       mlir::cast<mlir::TF::TensorStridedSliceUpdateOp>(op);
821   TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
822 
823   TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
824                       GetShapeOfValue(strided_slice_op.input(),
825                                       /*fail_on_dynamic=*/true));
826   const int input_rank = global_input_shape.size();
827   const int value_rank = ValueRank(strided_slice_op.value());
828 
829   llvm::DenseMap<int, Layout> input_layouts(strided_slice_op.getNumOperands());
830   // Set replicated layout for begin, end, and strides operands.
831   input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
832   input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
833   input_layouts[3] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
834 
835   // input and value layouts
836   if (output_layouts.find(0) != output_layouts.end()) {
837     const Layout& output_layout = output_layouts.lookup(0);
838     TF_ASSIGN_OR_RETURN(
839         const Layout proposed_input_layout,
840         GetStridedSliceIntermediateLayout(strided_slice_op, output_layout,
841                                           global_input_shape));
842     input_layouts[0] = proposed_input_layout;
843 
844     // We also need a layout for value as well, and for that we just take the
845     // input layout and apply the masks.
846     // The layout of value is determined from the input layout by applying the
847     // new and shrink masks in the forwards direction as value would have been
848     // the output layout for a normal strided slice operation.
849     TF_ASSIGN_OR_RETURN(
850         const Layout proposed_value_layout,
851         ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, value_rank,
852                                        proposed_input_layout, FORWARD));
853     input_layouts[4] = proposed_value_layout;
854   }
855 
856   return input_layouts;
857 }
858 
859 StatusOr<mlir::Operation*> StridedSliceGradSPMDExpander::ExpandOp(
860     mlir::Operation* op) {
861   auto strided_slice_grad_op = llvm::cast<mlir::TF::StridedSliceGradOp>(op);
862   TF_ASSIGN_OR_RETURN(
863       const Layout input_layout,
864       ExtractRequiredLayoutFromOperand(strided_slice_grad_op.dy()));
865   TF_ASSIGN_OR_RETURN(const Layout output_layout,
866                       ExtractRequiredSingleLayoutFromOp(op));
867 
868   TF_ASSIGN_OR_RETURN(
869       const llvm::ArrayRef<int64_t> global_output_shape,
870       GetGlobalShapeOfValueFromDTensorLayout(strided_slice_grad_op.output()));
871 
872   const int input_rank = ValueRank(strided_slice_grad_op.dy());
873   const int output_rank = global_output_shape.size();
874 
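  // StridedSliceGrad reverses the roles of StridedSlice: dy has the rank of the
  // forward op's output, while the result has the global shape of the forward
  // op's input, so the intermediate layout is derived from the output layout.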
875   llvm::SmallVector<int64_t, 4> end;
876   TF_ASSIGN_OR_RETURN(
877       Layout intermediate_output_layout,
878       GetStridedSliceIntermediateLayout(strided_slice_grad_op, output_layout,
879                                         global_output_shape, &end));
880 
881   TF_ASSIGN_OR_RETURN(Layout intermediate_input_layout,
882                       ApplyNewAndShrinkMasksToLayout(
883                           strided_slice_grad_op, output_rank, input_rank,
884                           intermediate_output_layout, FORWARD));
885 
886   TF_ASSIGN_OR_RETURN(mlir::Value new_dy,
887                       EmitRelayout(strided_slice_grad_op.dy(), input_layout,
888                                    intermediate_input_layout));
889 
890   strided_slice_grad_op.dyMutable().assign(new_dy);
891 
892   mlir::OpBuilder builder(op);
893 
894   if (!end.empty()) {
895     mlir::Value new_end =
896         IntConstWithMatchingType(builder, strided_slice_grad_op.getLoc(), end,
897                                  strided_slice_grad_op.begin().getType());
898     strided_slice_grad_op.endMutable().assign(new_end);
899   }
900 
901   // The shape input to StridedSliceGrad will still be global, so we need to
902   // compute the local shape and update it.
903   std::vector<int64_t> computed_output_shape =
904       intermediate_output_layout.LocalShapeFromGlobalShape(global_output_shape);
905   mlir::Value new_shape = IntConstWithMatchingType(
906       builder, strided_slice_grad_op.getLoc(), computed_output_shape,
907       strided_slice_grad_op.begin().getType());
908   strided_slice_grad_op.shapeMutable().assign(new_shape);
909 
910   op = InferSPMDExpandedLocalShape(op);
911 
912   // Do a final relayout to the correct output layout in case there are any
913   // differences between intermediate_output_layout and output_layout.
914   llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
915 
916   TF_ASSIGN_OR_RETURN(
917       mlir::Value output,
918       EmitRelayout(strided_slice_grad_op.output(), intermediate_output_layout,
919                    output_layout, &newly_created_ops));
920 
921   strided_slice_grad_op.output().replaceAllUsesExcept(output,
922                                                       newly_created_ops);
923 
924   return output.getDefiningOp();
925 }
926 
927 StatusOr<llvm::DenseMap<int, Layout>>
928 StridedSliceGradSPMDExpander::ComputeLayoutForward(
929     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
930   // If the input layout is missing, don't return an output layout.
931   if (input_layouts.find(4) == input_layouts.end())
932     return llvm::DenseMap<int, Layout>();
933 
934   mlir::TF::StridedSliceGradOp strided_slice_grad_op =
935       mlir::cast<mlir::TF::StridedSliceGradOp>(op);
936   TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_output_shape,
937                       GetShapeOfValue(strided_slice_grad_op.output(),
938                                       /*fail_on_dynamic=*/true));
939   const int input_rank = ValueRank(strided_slice_grad_op.dy());
940   const int output_rank = global_output_shape.size();
941 
942   const Layout& input_layout = input_layouts.lookup(4);
943   // If a dimension was added or removed, take the layout of dy (the forward
944   // op's output) and add/skip dimensions to get a layout for the grad output.
945   TF_ASSIGN_OR_RETURN(
946       Layout proposed_layout,
947       ApplyNewAndShrinkMasksToLayout(strided_slice_grad_op, output_rank,
948                                      input_rank, input_layout, BACKWARD));
949   TF_ASSIGN_OR_RETURN(
950       proposed_layout,
951       GetStridedSliceIntermediateLayout(strided_slice_grad_op, proposed_layout,
952                                         global_output_shape));
953   return llvm::DenseMap<int, Layout>({{0, proposed_layout}});
954 }
955 
956 StatusOr<llvm::DenseMap<int, Layout>>
957 StridedSliceGradSPMDExpander::ComputeLayoutBackward(
958     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
959   mlir::TF::StridedSliceGradOp strided_slice_grad_op =
960       mlir::cast<mlir::TF::StridedSliceGradOp>(op);
961   TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
962 
963   TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_output_shape,
964                       GetShapeOfValue(strided_slice_grad_op.output(),
965                                       /*fail_on_dynamic=*/true));
966   const int input_rank = ValueRank(strided_slice_grad_op.dy());
967   const int output_rank = global_output_shape.size();
968 
969   llvm::DenseMap<int, Layout> input_layouts(
970       strided_slice_grad_op.getNumOperands());
971   // Set replicated layout for shape, begin, end, stride operands.
972   input_layouts[0] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
973   input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
974   input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
975   input_layouts[3] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
976 
977   // dy
978   if (output_layouts.find(0) != output_layouts.end()) {
979     const Layout& output_layout = output_layouts.lookup(0);
980     TF_ASSIGN_OR_RETURN(
981         Layout proposed_layout,
982         GetStridedSliceIntermediateLayout(strided_slice_grad_op, output_layout,
983                                           global_output_shape));
984 
985     // If dimension was added or removed, create a new proposed output layout
986     // with dimensions added/skipped.
987     TF_ASSIGN_OR_RETURN(
988         proposed_layout,
989         ApplyNewAndShrinkMasksToLayout(strided_slice_grad_op, output_rank,
990                                        input_rank, proposed_layout, FORWARD));
991     input_layouts[4] = proposed_layout;
992   }
993 
994   return input_layouts;
995 }
996 
997 }  // namespace dtensor
998 }  // namespace tensorflow
999