/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/slice_spmd_expander.h"

#include <algorithm>
#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/numeric/bits.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {
namespace {

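// Descriptive note (added for readability): extracts the constant `begin` and
// `size` operands of a tf.Slice op. `size` must be constant; `begin` may be
// dynamic, in which case `dynamic_begins` is set to true and `begins` is left
// unpopulated.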
Status GetSliceOpArguments(mlir::TF::SliceOp slice_op,
                           llvm::SmallVector<int64_t, 4>& begins,
                           bool& dynamic_begins,
                           llvm::SmallVector<int64_t, 4>& sizes) {
  Status begins_result = ExtractConstVectorFromValue(slice_op.begin(), &begins);
  dynamic_begins = !begins_result.ok();

  TF_RETURN_WITH_CONTEXT(ExtractConstVectorFromValue(slice_op.size(), &sizes),
                         "expected constant argument for SliceOp::size()");

  return OkStatus();
}

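// Descriptive note (added for readability): proposes a layout under which the
// Slice can be computed locally. Dimensions that receive a full slice keep
// their sharding; every other sharded dimension is proposed as unsharded. When
// `global_shape` is provided it is used in place of the shape inferred from
// `value` (needed in ExpandOp, where the value already carries its local
// shape).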
StatusOr<Layout> VerifySliceLayout(
    mlir::Operation* slice_op, mlir::Value value, const Layout& layout,
    llvm::ArrayRef<int64_t>* global_shape = nullptr) {
  if (layout.IsFullyReplicated()) return layout;

  TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> shape,
                      GetShapeOfValue(value, /*fail_on_dynamic=*/true));
  const int64_t rank = shape.size();
  if (global_shape != nullptr) {
    // In ExpandOp, the tensor shape is the local shape, so the call site needs
    // to provide the global shape explicitly.
    shape = *global_shape;
  }

  llvm::SmallVector<int64_t, 4> begins, sizes;
  bool dynamic_begins = false;
  begins.reserve(rank);
  sizes.reserve(rank);

  TF_RETURN_IF_ERROR(GetSliceOpArguments(
      llvm::cast<mlir::TF::SliceOp>(slice_op), begins, dynamic_begins, sizes));

  auto num_shards = layout.num_shards();

  LayoutProto proposed_proto;
  *proposed_proto.mutable_mesh_config() = layout.mesh().ToProto();
  for (int64_t i = 0; i < rank; ++i) {
    // A slice over a replicated dimension translates to a local expansion.
    if (num_shards[i] == 1) {
      proposed_proto.add_sharding_specs()->set_sharding_spec(
          Layout::kUnshardedDim);
      continue;
    }

    const bool begins_starts_at_zero =
        (sizes[i] == shape[i]) || (!dynamic_begins && begins[i] == 0);
    const bool ends_at_full_size =
        (sizes[i] == shape[i]) || (!dynamic_begins && sizes[i] == -1);

    if (begins_starts_at_zero && ends_at_full_size) {
      // We support slicing with dynamic begins when the sharded dimensions are
      // getting a full slice. Since we don't know the begins in this case, we
      // need to rely on the sizes being static and equal to the global shape.
      // In particular, sizes[i] == shape[i] implies begins[i] == 0.
      // A full slice over any dimension can be performed locally.
      proposed_proto.add_sharding_specs()->set_sharding_spec(
          layout.sharding_spec(i));
    } else {
      // Slicing on a sharded dim is not trivial. Propose an unsharded dim for
      // it.
      proposed_proto.add_sharding_specs()->set_sharding_spec(
          Layout::kUnshardedDim);
    }
  }
  return Layout::FromProto(proposed_proto);
}

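// Descriptive note (added for readability): expands a mask attribute into a
// per-bit vector, least significant bit first. For example,
// CalculateBitVector(0b0101) returns {1, 0, 1, 0, 0, ..., 0} with 64 entries.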
llvm::SmallVector<int64_t, 4> CalculateBitVector(const uint64_t mask_value) {
  llvm::SmallVector<int64_t, 4> bit_vector;
  bit_vector.resize(sizeof(uint64_t) * 8, 0);
  for (int i = 0; i < sizeof(uint64_t) * 8; ++i) {
    bit_vector[i] = (mask_value >> i & 1);
  }
  return bit_vector;
}

// The begin/end/stride and the masks are all sized to match the number of
// entries in the slice specification. E.g. [:, ..., 3] will have a begin/end/
// stride of size 3 and the max set bit in the mask will be the 3rd bit.
// This function converts these specifications into ones relative to the input
// tensor.
// We also output a bool vector of the input indices which are not shrunk away.
// These must always be replicated, since shrinking an index means we took a
// single element along that axis and it must be present on all cores.
// spec_to_input maps the 'spec' dimensions to the input dimensions. This is
// needed so we can create a new 'end' input for the SPMD expanded op.
//
// NOTE: If the begin or end is dynamic, it will have size 0.
// If strides is dynamic it will be the correct rank but contain 0s (an invalid
// stride).
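//
// Illustrative example (added for readability): with the spec [:, ..., 3]
// applied to a rank-4 input, spec index 0 maps to input axis 0, the ellipsis
// covers input axes 1 and 2 (their begin and end mask bits are set), and spec
// index 2 (the shrunk literal `3`) maps to input axis 3, giving
// not_shrunk = {true, true, true, false} and spec_to_input = {0, -1, 3}.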
template <typename T>
Status GetInputOrientedData(T strided_slice,
                            llvm::SmallVectorImpl<int64_t>* begin,
                            uint64_t* begin_mask,
                            llvm::SmallVectorImpl<int64_t>* end,
                            uint64_t* end_mask,
                            llvm::SmallVectorImpl<int64_t>* strides,
                            llvm::SmallVectorImpl<bool>* not_shrunk,
                            llvm::SmallVectorImpl<int64>* spec_to_input) {
  begin->resize(0);
  end->resize(0);
  strides->resize(0);

  llvm::SmallVector<int64_t, 4> spec_begin;
  llvm::SmallVector<int64_t, 4> spec_end;
  llvm::SmallVector<int64_t, 4> spec_strides;

  TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> strides_shape,
                      GetShapeOfValue(strided_slice.strides(),
                                      /*fail_on_dynamic=*/true));
  if (strides_shape.size() != 1)
    return errors::InvalidArgument(
        "strides input to strided operation is not rank 1");

  int64_t spec_rank = strides_shape[0];
  spec_to_input->resize(spec_rank, -1);

  if (!ExtractConstVectorFromValue(strided_slice.strides(), &spec_strides).ok())
    spec_strides.resize(spec_rank, 0);

  if (ExtractConstVectorFromValue(strided_slice.begin(), &spec_begin).ok())
    if (spec_begin.size() != spec_rank)
      return errors::InvalidArgument(
          "rank of begin input to strided operation does not equal rank of "
          "strides input");

  if (ExtractConstVectorFromValue(strided_slice.end(), &spec_end).ok())
    if (spec_end.size() != spec_rank)
      return errors::InvalidArgument(
          "rank of end input to strided operation does not equal rank of "
          "strides input");

  const uint64_t new_axis_mask = strided_slice.new_axis_mask();
  const uint64_t shrink_axis_mask = strided_slice.shrink_axis_mask();
  const uint64_t spec_begin_mask = strided_slice.begin_mask();
  const uint64_t spec_end_mask = strided_slice.end_mask();
  uint64_t ellipsis_mask = strided_slice.ellipsis_mask();

  int64_t input_rank;
  if (mlir::isa<mlir::TF::StridedSliceOp>(strided_slice) ||
      mlir::isa<mlir::TF::TensorStridedSliceUpdateOp>(strided_slice)) {
    // For StridedSlice the first operand is the input.
    input_rank = ValueRank(strided_slice->getOperand(0));
  } else if (mlir::isa<mlir::TF::StridedSliceGradOp>(strided_slice)) {
    // For StridedSliceGrad the first operand is the shape of the input.
    TF_ASSIGN_OR_RETURN(llvm::ArrayRef<int64_t> input_shape,
                        GetShapeOfValue(strided_slice->getOperand(0)));
    if (input_shape.size() != 1)
      return errors::InvalidArgument("input shape must be rank 1");
    input_rank = input_shape[0];
  }

  if (absl::popcount(ellipsis_mask) > 1)
    return errors::InvalidArgument(
        "strided slice only supports at most one ellipsis");

  // Count the number of axes after the ellipsis.
  bool found_ellipsis = false;
  int64_t num_add_axis_after_ellipsis = 0;
  for (int64_t i = 0; i < spec_rank; ++i) {
    if (found_ellipsis && ((1 << i) & new_axis_mask))
      num_add_axis_after_ellipsis++;
    if ((1 << i) & ellipsis_mask) found_ellipsis = true;
  }
  // Guarantee one ellipsis. If there isn't one, add it at the end of the spec.
  // If we do this, add one to the total rank so that we process the ellipsis
  // as part of the loop below.
  if (!found_ellipsis) ellipsis_mask |= (1 << (spec_rank++));

  // At this point the total rank cannot be more than input_rank plus the
  // number of new axes plus the number of ellipses. Check that condition so
  // that we know the loop below won't have input_index >= input_rank.
  if (spec_rank > input_rank + absl::popcount(new_axis_mask) + 1)
    return errors::InvalidArgument(
        "incompatible input rank, number of new axis and specification rank: ",
        input_rank, ", ", absl::popcount(new_axis_mask), ", ", spec_rank);

  int64_t input_index = 0;
  for (int64_t spec_index = 0; spec_index < spec_rank; ++spec_index) {
    if ((1 << spec_index) & ellipsis_mask) {
      const int64_t next_input_index =
          std::min(input_rank - (spec_rank - spec_index) + 1 +
                       num_add_axis_after_ellipsis,
                   input_rank);
      for (; input_index < next_input_index; input_index++) {
        // For input axes within the ellipsis region, we include the entire
        // axis by setting the begin and end mask.
        not_shrunk->emplace_back(true);
        if (!spec_begin.empty()) begin->emplace_back(0);
        if (!spec_end.empty()) end->emplace_back(0);
        strides->emplace_back(1);
        (*begin_mask) |= 1 << input_index;
        (*end_mask) |= 1 << input_index;
      }
    } else if (((1 << spec_index) & new_axis_mask) == 0) {
      not_shrunk->emplace_back(((1 << spec_index) & shrink_axis_mask) == 0);
      if (!spec_begin.empty()) begin->emplace_back(spec_begin[spec_index]);
      if (!spec_end.empty()) end->emplace_back(spec_end[spec_index]);
      strides->emplace_back(spec_strides[spec_index]);
      (*spec_to_input)[spec_index] = input_index;
      (*begin_mask) |= ((spec_begin_mask >> spec_index) & 1) << input_index;
      (*end_mask) |= ((spec_end_mask >> spec_index) & 1) << input_index;
      input_index++;
    }
  }

  // This should not happen.
  if (input_index != input_rank)
    return errors::Internal("strided slice input not totally processed");

  return OkStatus();
}

// Returns an intermediate layout for StridedSlice(Grad), where we can lower
// the global StridedSlice(Grad) to a local one.
// All the inputs (begin/end/stride/masks) are sized to match the 'total rank',
// which is the rank of the input plus the number of new dimensions added
// (e.g. the number of bits set in the new_axis_mask).
// The values of these inputs on the 'newly added' dimensions are ignored.
// global_input_shape is the global shape of the main input of StridedSlice
// or, equivalently, the global shape of the output of StridedSliceGrad.
// If new_end is not a nullptr, it will be set to the new ending vector if
// the end was constant; otherwise it will be cleared.
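//
// Illustrative example (added for readability): an axis of global size 8
// split over 2 shards can stay sharded for the slice [0:8:2]: the local size
// 4 is divisible by the stride 2 and the slice spans the whole axis, so
// `new_end` for that axis becomes the local size 4. By contrast, [0:6:2] over
// the same axis leaves 2 unused elements (>= the stride), so that axis is
// proposed as unsharded.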
template <typename T>
StatusOr<Layout> GetStridedSliceIntermediateLayout(
    T strided_slice, const Layout& layout,
    const llvm::ArrayRef<int64_t> global_input_shape,
    llvm::SmallVectorImpl<int64_t>* new_end = nullptr) {
  const int64_t rank = global_input_shape.size();

  // Records whether the corresponding axis of the input can be sharded.
  llvm::SmallVector<bool, 4> can_shard;
  // Lists the start/end of the slice, clamped to the valid range.
  llvm::SmallVector<int64_t, 4> begin;
  llvm::SmallVector<int64_t, 4> end;
  // Lists the stride for each tensor dimension. Positive when it is constant
  // and 0 when it is dynamic.
  llvm::SmallVector<int64_t, 4> strides;
  llvm::SmallVector<int64_t, 4> total_to_input;
  // The current number of shards along each axis.
  const std::vector<int32> shards = layout.num_shards();

  uint64_t begin_mask = 0;
  uint64_t end_mask = 0;

  TF_RETURN_IF_ERROR(GetInputOrientedData(strided_slice, &begin, &begin_mask,
                                          &end, &end_mask, &strides,
                                          &can_shard, &total_to_input));

  bool const_begin = !begin.empty();
  bool const_end = !end.empty();

  if (!const_begin) begin.resize(rank, 0);

  if (!const_end) end.resize(rank, 0);

  for (int i = 0; i < rank; ++i) {
    if ((1 << i) & begin_mask)
      begin[i] = 0;
    else if (begin[i] < 0)
      begin[i] += global_input_shape[i];

    if (begin[i] < 0l) {
      begin[i] = 0l;
    } else if (begin[i] > global_input_shape[i] - 1) {
      begin[i] = global_input_shape[i] - 1;
    }

    if ((1 << i) & end_mask)
      end[i] = global_input_shape[i];
    else if (end[i] < 0)
      end[i] += global_input_shape[i];

    if (end[i] < 1l) {
      end[i] = 1l;
    } else if (end[i] > global_input_shape[i]) {
      end[i] = global_input_shape[i];
    }

    // A negative or dynamic stride requires an unsharded axis.
    if (strides[i] < 1) can_shard[i] = false;
    // The local size must be divisible by the stride, otherwise the begin
    // for each local slice would be different. (Only check when the stride is
    // a positive constant to avoid dividing by a dynamic stride of 0.)
    if (strides[i] >= 1 &&
        (global_input_shape[i] / shards[i]) % strides[i] != 0)
      can_shard[i] = false;
    // If start or end are dynamic we can't shard.
    if (!(((1 << i) & begin_mask) || const_begin) ||
        !(((1 << i) & end_mask) || const_end))
      can_shard[i] = false;
    // Finally, if the amount of space left on the 'left' and 'right' of the
    // tensor is more than (or equal to) a stride then we can't shard, as there
    // would be an unequal number of outputs per shard.
    // NOTE: the case of end[i] == begin[i] may be a simple optimization since
    // the result is an empty tensor.
    if (global_input_shape[i] - (end[i] - begin[i]) >= strides[i])
      can_shard[i] = false;
    // If there is currently no sharding, it doesn't make sense to shard.
    if (shards[i] == 1) can_shard[i] = false;
  }

  // Compute the new 'end' for the slice. Note that this end needs to be in
  // terms of the 'total' index, not the input index (i.e. it needs 'bogus'
  // entries for the new axes).
  if (new_end != nullptr) {
    if (!const_end) {
      // A dynamic end is left unchanged. We indicate this by ensuring the
      // passed-in vector is empty.
      new_end->clear();
    } else {
      new_end->resize(total_to_input.size());
      for (int i = 0; i < total_to_input.size(); ++i) {
        const int64_t inp = total_to_input[i];
        if (inp != -1) {
          // If we can keep input axis inp sharded, we need to update the end.
          // Given the conditions we enforced above, we can set end to the
          // local size of the input.
          if (can_shard[inp])
            (*new_end)[i] = global_input_shape[inp] / shards[inp];
          else
            (*new_end)[i] = end[inp];
        }
      }
    }
  }

  // Compute the new layout; it is basically the old layout but replicated on
  // some axes.
  absl::flat_hash_set<int> reduced_dims;
  for (int i = 0; i < rank; ++i)
    if (!can_shard[i]) reduced_dims.emplace(i);
  return layout.GetLayoutWithReducedDims(reduced_dims, /*keep_dims=*/true);
}

enum Direction {
  FORWARD,
  BACKWARD,
};

// Applies the shrink and new-axis masks to a layout. This function works in
// both the forward and backward directions, as specified by the direction
// argument.
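//
// Illustrative example (added for readability): with shrink_axis_mask = 1 and
// no new axes, FORWARD maps the input layout ["x", "y"] to the output layout
// ["y"] (the shrunk dimension's sharding is dropped), while BACKWARD maps the
// output layout ["y"] back to the input layout ["unsharded", "y"] (the shrunk
// dimension is restored as unsharded).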
template <typename SliceOpT>
StatusOr<Layout> ApplyNewAndShrinkMasksToLayout(SliceOpT slice_op,
                                                const int input_rank,
                                                const int output_rank,
                                                const Layout& proposed_layout,
                                                const Direction direction) {
  // Calculate bit mask for shrunk dimensions/newly added dimensions.
  const llvm::SmallVector<int64_t, 4> new_axis_mask =
      CalculateBitVector(slice_op.new_axis_mask());
  const llvm::SmallVector<int64_t, 4> shrink_axis_mask =
      CalculateBitVector(slice_op.shrink_axis_mask());

  std::vector<std::string> sharding_spec;
  int input_dim_index = 0;
  int output_dim_index = 0;
  int current_dimension_index = 0;
  while (current_dimension_index < proposed_layout.rank()) {
    if (input_dim_index < input_rank &&
        shrink_axis_mask[input_dim_index] == 1) {
      input_dim_index++;
      if (direction == BACKWARD)
        sharding_spec.emplace_back(Layout::kUnshardedDim);
      else
        current_dimension_index++;
    } else if (output_dim_index < output_rank &&
               new_axis_mask[output_dim_index] == 1) {
      if (direction == FORWARD)
        sharding_spec.emplace_back(Layout::kUnshardedDim);
      else
        current_dimension_index++;
      output_dim_index++;
    } else {
      sharding_spec.emplace_back(
          proposed_layout.sharding_spec(current_dimension_index));
      input_dim_index++;
      output_dim_index++;
      current_dimension_index++;
    }
  }

  const auto& mask = (direction == FORWARD) ? new_axis_mask : shrink_axis_mask;
  // New dimensions may be added after all dimensions have been sliced.
  while (current_dimension_index < mask.size() &&
         mask[current_dimension_index] == 1) {
    sharding_spec.emplace_back(Layout::kUnshardedDim);
    current_dimension_index++;
  }

  return Layout::GetLayout(sharding_spec, proposed_layout.mesh());
}

mlir::Value IntConstWithMatchingType(mlir::OpBuilder& builder,
                                     mlir::Location loc,
                                     llvm::ArrayRef<int64_t> values,
                                     mlir::Type type) {
  if (type.cast<mlir::RankedTensorType>().getElementType().isInteger(64)) {
    return Int64Const(builder, loc, values);
  } else {
    llvm::SmallVector<int32, 4> values32(values.begin(), values.end());
    return IntConst(builder, loc, values32);
  }
}

}  // namespace

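// Descriptive note (added for readability): SPMD expansion for tf.Slice.
// Relayout the input to a layout under which the slice is local (sharded
// dimensions must receive a full slice), rewrite the `size` operand to local
// sizes, emit a local Slice, and finally relayout the result to the requested
// output layout.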
StatusOr<mlir::Operation*> SliceSPMDExpander::ExpandOp(mlir::Operation* op) {
  auto slice_op = mlir::cast<mlir::TF::SliceOp>(op);
  TF_ASSIGN_OR_RETURN(auto input_layout,
                      ExtractLayoutFromOperand(slice_op.input()));
  TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));

  if (!output_layout || !input_layout)
    return errors::Unimplemented(
        "layout of Slice op must be known before SPMD expansion.");

  // The dyn_cast will never be nullptr as it is checked in
  // GetLayoutFromOperands.
  auto input_type =
      slice_op.input().getType().dyn_cast<mlir::RankedTensorType>();
  if (!input_type)
    return errors::InvalidArgument(
        "rank of input tensor must be statically known for slice op.");

  TF_ASSIGN_OR_RETURN(auto global_shape,
                      ExtractGlobalInputShape(op->getOpOperand(0)));
  const int64_t input_rank = input_type.getRank();

  llvm::SmallVector<int64_t, 4> begins, sizes;
  bool dynamic_begins = false;
  begins.reserve(input_rank);
  sizes.reserve(input_rank);

  TF_RETURN_IF_ERROR(
      GetSliceOpArguments(slice_op, begins, dynamic_begins, sizes));

  TF_ASSIGN_OR_RETURN(auto proposed_layout,
                      VerifySliceLayout(slice_op, slice_op.input(),
                                        *input_layout, &global_shape));

  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;

  TF_ASSIGN_OR_RETURN(auto relayout_input,
                      EmitRelayout(op->getOperand(0), *input_layout,
                                   proposed_layout, &newly_created_ops));
  {
    // Adjusts the sizes when there is full slicing on a sharded dimension.
    // Note that the proposed layout is unsharded in the cases where:
    // 1) We can't determine the begins and sizes != global shape
    // 2) begins != 0
    // 3) sizes != global shape or -1
    const std::vector<int> num_shards = proposed_layout.num_shards();
    for (int64_t i = 0; i < input_rank; ++i) {
      if (num_shards[i] == 1) continue;

      if (sizes[i] == -1 && !dynamic_begins && begins[i] == 0) continue;

      if (sizes[i] == global_shape[i]) {
        // Set the correct output size. If the input is dynamic and this is
        // left as -1, shape inference can't tell the output shape.
        sizes[i] = global_shape[i] / num_shards[i];
        continue;
      }

      return errors::InvalidArgument(
          "Non-full-slicing on the sharded dimension is not allowed. "
          "internal bug.");
    }
  }

  mlir::OpBuilder builder(op);
  mlir::Value new_size;
  auto loc = op->getLoc();
  // Both begin and size need to be the same type, so we must match the new
  // size input with the type of begin.
  if (!slice_op.begin().getType().isa<mlir::ShapedType>())
    return errors::Internal("type of begin is not a ShapedType");
  mlir::ShapedType type = slice_op.begin().getType().cast<mlir::ShapedType>();
  if (type.getElementType().isInteger(32))
    new_size = IntConst(
        builder, loc, llvm::SmallVector<int32, 4>(sizes.begin(), sizes.end()));
  else
    new_size = Int64Const(builder, loc, sizes);

  auto new_op =
      builder
          .create<mlir::TF::SliceOp>(loc, slice_op.output().getType(),
                                     relayout_input, slice_op.begin(), new_size)
          .getOperation();
  new_op = InferSPMDExpandedLocalShape(new_op);

  TF_ASSIGN_OR_RETURN(auto relayout_output,
                      EmitRelayout(new_op->getResult(0), proposed_layout,
                                   *output_layout, &newly_created_ops));

  op->getOpResult(0).replaceAllUsesExcept(relayout_output, newly_created_ops);
  op->erase();
  return relayout_output.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>> SliceSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  // If the input layout is missing, don't return an output layout.
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  auto slice_op = mlir::cast<mlir::TF::SliceOp>(op);

  const Layout& input_layout = input_layouts.lookup(0);
  TF_ASSIGN_OR_RETURN(
      auto proposed_layout,
      VerifySliceLayout(slice_op, slice_op.input(), input_layout));
  return llvm::DenseMap<int, Layout>({{0, proposed_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>> SliceSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  auto slice_op = mlir::cast<mlir::TF::SliceOp>(op);
  TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));

  llvm::DenseMap<int, Layout> input_layouts(slice_op.getNumOperands());
  // Set replicated layout for begin and size operands.
  input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);

  // input
  if (output_layouts.find(0) != output_layouts.end()) {
    const Layout& output_layout = output_layouts.lookup(0);
    TF_ASSIGN_OR_RETURN(
        auto proposed_layout,
        VerifySliceLayout(slice_op, slice_op.output(), output_layout));
    input_layouts[0] = proposed_layout;
  }

  return input_layouts;
}

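// Descriptive note (added for readability): SPMD expansion for
// tf.StridedSlice. Relayout the input to an intermediate layout under which
// the strided slice is local, rewrite the `end` operand to local bounds where
// needed, then relayout the result (with the new/shrink masks applied to the
// intermediate layout) to the requested output layout.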
StatusOr<mlir::Operation*> StridedSliceSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  auto strided_slice_op = mlir::cast<mlir::TF::StridedSliceOp>(op);
  TF_ASSIGN_OR_RETURN(Layout input_layout, ExtractRequiredLayoutFromOperand(
                                               strided_slice_op.input()));
  TF_ASSIGN_OR_RETURN(Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(op));
  TF_ASSIGN_OR_RETURN(
      const llvm::ArrayRef<int64_t> global_input_shape,
      GetGlobalShapeOfValueFromDTensorLayout(strided_slice_op.input()));

  llvm::SmallVector<int64_t, 4> end;
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_input_layout,
      GetStridedSliceIntermediateLayout(strided_slice_op, input_layout,
                                        global_input_shape, &end));

  TF_ASSIGN_OR_RETURN(mlir::Value new_input,
                      EmitRelayout(strided_slice_op.input(), input_layout,
                                   intermediate_input_layout));

  strided_slice_op.inputMutable().assign(new_input);

  mlir::OpBuilder builder(op);

  if (!end.empty()) {
    mlir::Value new_end =
        IntConstWithMatchingType(builder, strided_slice_op.getLoc(), end,
                                 strided_slice_op.begin().getType());
    strided_slice_op.endMutable().assign(new_end);
  }

  op = InferSPMDExpandedLocalShape(op);

  // Compute the layout of the output after the local StridedSlice takes place.
  const int input_rank = global_input_shape.size();
  const int output_rank = ValueRank(strided_slice_op.output());

  // Calculate bit mask for shrunk dimensions/newly added dimensions.
  const llvm::SmallVector<int64_t, 4> new_axis_mask =
      CalculateBitVector(strided_slice_op.new_axis_mask());
  const llvm::SmallVector<int64_t, 4> shrink_axis_mask =
      CalculateBitVector(strided_slice_op.shrink_axis_mask());

  TF_ASSIGN_OR_RETURN(
      Layout intermediate_output_layout,
      ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, output_rank,
                                     intermediate_input_layout, FORWARD));

  // Do a final relayout to the correct output layout in case there are any
  // differences between intermediate_output_layout and output_layout.
  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;

  TF_ASSIGN_OR_RETURN(
      mlir::Value output,
      EmitRelayout(strided_slice_op.output(), intermediate_output_layout,
                   output_layout, &newly_created_ops));

  strided_slice_op.output().replaceAllUsesExcept(output, newly_created_ops);

  return output.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>>
StridedSliceSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  // If the input layout is missing, don't return an output layout.
  if (input_layouts.find(0) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  mlir::TF::StridedSliceOp strided_slice_op =
      mlir::cast<mlir::TF::StridedSliceOp>(op);
  TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
                      GetShapeOfValue(strided_slice_op.input(),
                                      /*fail_on_dynamic=*/true));
  const int input_rank = global_input_shape.size();
  const int output_rank = ValueRank(strided_slice_op.output());

  const Layout& input_layout = input_layouts.lookup(0);
  TF_ASSIGN_OR_RETURN(Layout proposed_layout,
                      GetStridedSliceIntermediateLayout(
                          strided_slice_op, input_layout, global_input_shape));
  // If a dimension was added or removed, create a new proposed output layout
  // with dimensions added/skipped.
  TF_ASSIGN_OR_RETURN(
      proposed_layout,
      ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, output_rank,
                                     proposed_layout, FORWARD));
  return llvm::DenseMap<int, Layout>({{0, proposed_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>>
StridedSliceSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  mlir::TF::StridedSliceOp strided_slice_op =
      mlir::cast<mlir::TF::StridedSliceOp>(op);
  TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));

  TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
                      GetShapeOfValue(strided_slice_op.input(),
                                      /*fail_on_dynamic=*/true));
  const int input_rank = global_input_shape.size();
  const int output_rank = ValueRank(strided_slice_op.output());

  llvm::DenseMap<int, Layout> input_layouts(strided_slice_op.getNumOperands());
  // Set replicated layout for the begin, end, and strides operands.
  input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[3] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);

  // input
  if (output_layouts.find(0) != output_layouts.end()) {
    // This layout must exist (as there is only one output).
    const Layout& output_layout = output_layouts.lookup(0);
    // If a dimension was added or removed, take the current output layout,
    // and add/skip dimensions in it as needed to get an input layout.
    TF_ASSIGN_OR_RETURN(
        Layout proposed_layout,
        ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank,
                                       output_rank, output_layout, BACKWARD));
    TF_ASSIGN_OR_RETURN(proposed_layout, GetStridedSliceIntermediateLayout(
                                             strided_slice_op, proposed_layout,
                                             global_input_shape));
    input_layouts[0] = proposed_layout;
  }

  return input_layouts;
}

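// Descriptive note (added for readability): SPMD expansion for
// tf.TensorStridedSliceUpdate. Relayout both the input and the update value
// to compatible intermediate layouts under which the update is local (the
// value layout is the input layout with the new/shrink masks applied), rewrite
// the `end` operand to local bounds where needed, then relayout the result to
// the requested output layout.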
StatusOr<mlir::Operation*> TensorStridedSliceUpdateSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  mlir::TF::TensorStridedSliceUpdateOp strided_slice_op =
      llvm::cast<mlir::TF::TensorStridedSliceUpdateOp>(op);
  TF_ASSIGN_OR_RETURN(
      const Layout input_layout,
      ExtractRequiredLayoutFromOperand(strided_slice_op.input()));
  TF_ASSIGN_OR_RETURN(
      const Layout value_layout,
      ExtractRequiredLayoutFromOperand(strided_slice_op.value()));
  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(op));

  TF_ASSIGN_OR_RETURN(
      const llvm::ArrayRef<int64_t> global_input_shape,
      GetGlobalShapeOfValueFromDTensorLayout(strided_slice_op.input()));

  const int input_rank = global_input_shape.size();
  const int value_rank = ValueRank(strided_slice_op.value());

  llvm::SmallVector<int64_t, 4> end;
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_input_layout,
      GetStridedSliceIntermediateLayout(strided_slice_op, input_layout,
                                        global_input_shape, &end));

  TF_ASSIGN_OR_RETURN(
      Layout intermediate_value_layout,
      ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, value_rank,
                                     intermediate_input_layout, FORWARD));

  TF_ASSIGN_OR_RETURN(mlir::Value new_input,
                      EmitRelayout(strided_slice_op.input(), input_layout,
                                   intermediate_input_layout));

  TF_ASSIGN_OR_RETURN(mlir::Value new_value,
                      EmitRelayout(strided_slice_op.value(), value_layout,
                                   intermediate_value_layout));

  strided_slice_op.inputMutable().assign(new_input);
  strided_slice_op.valueMutable().assign(new_value);

  mlir::OpBuilder builder(op);

  if (!end.empty()) {
    mlir::Value new_end =
        IntConstWithMatchingType(builder, strided_slice_op.getLoc(), end,
                                 strided_slice_op.begin().getType());
    strided_slice_op.endMutable().assign(new_end);
  }

  op = InferSPMDExpandedLocalShape(op);

  // Do a final relayout to the correct output layout in case there are any
  // differences between intermediate_input_layout and output_layout.
  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;

  TF_ASSIGN_OR_RETURN(
      mlir::Value output,
      EmitRelayout(strided_slice_op.output(), intermediate_input_layout,
                   output_layout, &newly_created_ops));

  strided_slice_op.output().replaceAllUsesExcept(output, newly_created_ops);

  return output.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>>
TensorStridedSliceUpdateSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  // If both the input layout and the value layout are missing, don't return
  // an output layout.
  if (input_layouts.find(0) == input_layouts.end() &&
      input_layouts.find(4) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  mlir::TF::TensorStridedSliceUpdateOp strided_slice_op =
      mlir::cast<mlir::TF::TensorStridedSliceUpdateOp>(op);
  TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
                      GetShapeOfValue(strided_slice_op.input(),
                                      /*fail_on_dynamic=*/true));
  const int input_rank = global_input_shape.size();
  const int value_rank = ValueRank(strided_slice_op.value());

  // We have a choice in determining the output layout: we default to the
  // input_layout if available, otherwise we expand the value_layout and use
  // that.
  Layout input_layout;
  if (input_layouts.find(0) != input_layouts.end()) {
    input_layout = input_layouts.lookup(0);
  } else {
    // When we don't have the input layout, use the value layout to 'create'
    // the input layout. We do this by applying the new and shrink masks
    // backwards, because in the case of a normal strided slice the layout of
    // value would be the output layout.
    const Layout& value_layout = input_layouts.lookup(4);
    TF_ASSIGN_OR_RETURN(input_layout, ApplyNewAndShrinkMasksToLayout(
                                          strided_slice_op, input_rank,
                                          value_rank, value_layout, BACKWARD));
  }
  TF_ASSIGN_OR_RETURN(Layout proposed_output_layout,
                      GetStridedSliceIntermediateLayout(
                          strided_slice_op, input_layout, global_input_shape));

  return llvm::DenseMap<int, Layout>({{0, proposed_output_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>>
TensorStridedSliceUpdateSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  mlir::TF::TensorStridedSliceUpdateOp strided_slice_op =
      mlir::cast<mlir::TF::TensorStridedSliceUpdateOp>(op);
  TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));

  TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_input_shape,
                      GetShapeOfValue(strided_slice_op.input(),
                                      /*fail_on_dynamic=*/true));
  const int input_rank = global_input_shape.size();
  const int value_rank = ValueRank(strided_slice_op.value());

  llvm::DenseMap<int, Layout> input_layouts(strided_slice_op.getNumOperands());
  // Set replicated layout for the begin, end, and strides operands.
  input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[3] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);

  // input and value layouts
  if (output_layouts.find(0) != output_layouts.end()) {
    const Layout& output_layout = output_layouts.lookup(0);
    TF_ASSIGN_OR_RETURN(
        const Layout proposed_input_layout,
        GetStridedSliceIntermediateLayout(strided_slice_op, output_layout,
                                          global_input_shape));
    input_layouts[0] = proposed_input_layout;

    // We also need a layout for value, and for that we just take the input
    // layout and apply the masks.
    // The layout of value is determined from the input layout by applying the
    // new and shrink masks in the forward direction, as value would have been
    // the output layout for a normal strided slice operation.
    TF_ASSIGN_OR_RETURN(
        const Layout proposed_value_layout,
        ApplyNewAndShrinkMasksToLayout(strided_slice_op, input_rank, value_rank,
                                       proposed_input_layout, FORWARD));
    input_layouts[4] = proposed_value_layout;
  }

  return input_layouts;
}

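// Descriptive note (added for readability): SPMD expansion for
// tf.StridedSliceGrad. Relayout dy to the intermediate layout implied by the
// local forward slice, rewrite the `end` operand (when constant) and the
// `shape` operand to local values, then relayout the result to the requested
// output layout.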
StatusOr<mlir::Operation*> StridedSliceGradSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  auto strided_slice_grad_op = llvm::cast<mlir::TF::StridedSliceGradOp>(op);
  TF_ASSIGN_OR_RETURN(
      const Layout input_layout,
      ExtractRequiredLayoutFromOperand(strided_slice_grad_op.dy()));
  TF_ASSIGN_OR_RETURN(const Layout output_layout,
                      ExtractRequiredSingleLayoutFromOp(op));

  TF_ASSIGN_OR_RETURN(
      const llvm::ArrayRef<int64_t> global_output_shape,
      GetGlobalShapeOfValueFromDTensorLayout(strided_slice_grad_op.output()));

  const int input_rank = ValueRank(strided_slice_grad_op.dy());
  const int output_rank = global_output_shape.size();

  llvm::SmallVector<int64_t, 4> end;
  TF_ASSIGN_OR_RETURN(
      Layout intermediate_output_layout,
      GetStridedSliceIntermediateLayout(strided_slice_grad_op, output_layout,
                                        global_output_shape, &end));

  TF_ASSIGN_OR_RETURN(Layout intermediate_input_layout,
                      ApplyNewAndShrinkMasksToLayout(
                          strided_slice_grad_op, output_rank, input_rank,
                          intermediate_output_layout, FORWARD));

  TF_ASSIGN_OR_RETURN(mlir::Value new_dy,
                      EmitRelayout(strided_slice_grad_op.dy(), input_layout,
                                   intermediate_input_layout));

  strided_slice_grad_op.dyMutable().assign(new_dy);

  mlir::OpBuilder builder(op);

  if (!end.empty()) {
    mlir::Value new_end =
        IntConstWithMatchingType(builder, strided_slice_grad_op.getLoc(), end,
                                 strided_slice_grad_op.begin().getType());
    strided_slice_grad_op.endMutable().assign(new_end);
  }

  // The shape input to StridedSliceGrad will still be global, so we need to
  // compute the local shape and update it.
  std::vector<int64_t> computed_output_shape =
      intermediate_output_layout.LocalShapeFromGlobalShape(global_output_shape);
  mlir::Value new_shape = IntConstWithMatchingType(
      builder, strided_slice_grad_op.getLoc(), computed_output_shape,
      strided_slice_grad_op.begin().getType());
  strided_slice_grad_op.shapeMutable().assign(new_shape);

  op = InferSPMDExpandedLocalShape(op);

  // Do a final relayout to the correct output layout in case there are any
  // differences between intermediate_output_layout and output_layout.
  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;

  TF_ASSIGN_OR_RETURN(
      mlir::Value output,
      EmitRelayout(strided_slice_grad_op.output(), intermediate_output_layout,
                   output_layout, &newly_created_ops));

  strided_slice_grad_op.output().replaceAllUsesExcept(output,
                                                      newly_created_ops);

  return output.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>>
StridedSliceGradSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  // If the dy layout (operand 4) is missing, don't return an output layout.
  if (input_layouts.find(4) == input_layouts.end())
    return llvm::DenseMap<int, Layout>();

  mlir::TF::StridedSliceGradOp strided_slice_grad_op =
      mlir::cast<mlir::TF::StridedSliceGradOp>(op);
  TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_output_shape,
                      GetShapeOfValue(strided_slice_grad_op.output(),
                                      /*fail_on_dynamic=*/true));
  const int input_rank = ValueRank(strided_slice_grad_op.dy());
  const int output_rank = global_output_shape.size();

  const Layout& input_layout = input_layouts.lookup(4);
  // If a dimension was added or removed, take the current dy layout, and
  // add/skip dimensions in it as needed to get an output layout.
  TF_ASSIGN_OR_RETURN(
      Layout proposed_layout,
      ApplyNewAndShrinkMasksToLayout(strided_slice_grad_op, output_rank,
                                     input_rank, input_layout, BACKWARD));
  TF_ASSIGN_OR_RETURN(
      proposed_layout,
      GetStridedSliceIntermediateLayout(strided_slice_grad_op, proposed_layout,
                                        global_output_shape));
  return llvm::DenseMap<int, Layout>({{0, proposed_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>>
StridedSliceGradSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  mlir::TF::StridedSliceGradOp strided_slice_grad_op =
      mlir::cast<mlir::TF::StridedSliceGradOp>(op);
  TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));

  TF_ASSIGN_OR_RETURN(const llvm::ArrayRef<int64_t> global_output_shape,
                      GetShapeOfValue(strided_slice_grad_op.output(),
                                      /*fail_on_dynamic=*/true));
  const int input_rank = ValueRank(strided_slice_grad_op.dy());
  const int output_rank = global_output_shape.size();

  llvm::DenseMap<int, Layout> input_layouts(
      strided_slice_grad_op.getNumOperands());
  // Set replicated layout for the shape, begin, end, and strides operands.
  input_layouts[0] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[1] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[2] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);
  input_layouts[3] = Layout::ReplicatedOnMesh(mesh, /*rank=*/1);

  // dy
  if (output_layouts.find(0) != output_layouts.end()) {
    const Layout& output_layout = output_layouts.lookup(0);
    TF_ASSIGN_OR_RETURN(
        Layout proposed_layout,
        GetStridedSliceIntermediateLayout(strided_slice_grad_op, output_layout,
                                          global_output_shape));

    // If a dimension was added or removed, create a new proposed dy layout
    // with dimensions added/skipped.
    TF_ASSIGN_OR_RETURN(
        proposed_layout,
        ApplyNewAndShrinkMasksToLayout(strided_slice_grad_op, output_rank,
                                       input_rank, proposed_layout, FORWARD));
    input_layouts[4] = proposed_layout;
  }

  return input_layouts;
}

}  // namespace dtensor
}  // namespace tensorflow