xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/strided_slice_op.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
16 #define TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
17 
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/lib/core/status.h"
22 #include "tensorflow/core/lib/gtl/inlined_vector.h"
23 
24 namespace tensorflow {
25 
26 struct StridedSliceShapeSpec {
27   // Begin mask canonlized in dense form.
28   int32_t begin_dense_mask;
29   // End mask canonlized in dense form.
30   int32_t end_dense_mask;
31   // Shrink axis mask canonlized in dense form.
32   int32_t shrink_axis_dense_mask;
33   // output_to_sparse_mapping[i] represents output[i]'s the corresponding dim
34   // index in the begin_tensor. If
35   // output_to_sparse_mapping[i] is -1, it means the dimension doesn't show up
36   // in sparse_mapping.
37   gtl::InlinedVector<int64_t, 4> output_to_sparse_mapping;
38   // output_to_processing_mapping is similar to output_to_sparse_mapping, but
39   // for processing shape.
40   gtl::InlinedVector<int64_t, 4> output_to_processing_mapping;
41   // processing_to_sparse_mapping[i] represents input_shape[i]'s corresponding
42   // dim index in the begin_tensor.
43   gtl::InlinedVector<int64_t, 4> processing_to_sparse_mapping;
44 };
45 
46 // Runs validation on the strided slice op parameters.
47 //
48 // Is a separate translation unit from the kernel so that:
49 // 1. The op's shape function can use it.
50 // 2. The code size is reduced vs templating this on the kernel's type.
51 //
52 // Note that when input_shape is not fully specified, only <final_shape> and
53 // <processing_shape> are valid; <is_identity>, <is_simple_slice> and other
54 // output parameters will not be accurate.
55 //
56 // If <begin_tensor> or <end_tensor> are nullptr, <begin> and <end> will not be
57 // valid. In this case, <slice_dim0> and <is_identity> will be true only if a
58 // determination can be made based on the information given. A best effort is
59 // made to set <processing_shape> and <final_shape> based on <input_shape>, but
60 // some dimensions of <processing_shape> and/or <final_shape> may be unknown
61 // (-1). Any validation that can be done without complete information is
62 // performed.
63 //
64 Status ValidateStridedSliceOp(
65     const Tensor* begin_tensor, const Tensor* end_tensor,
66     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
67     int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
68     int32_t new_axis_mask, int32_t shrink_axis_mask,
69     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
70     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
71     gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
72     gtl::InlinedVector<int64_t, 4>* strides,
73     StridedSliceShapeSpec* shape_spec = nullptr);
74 
75 // Same as above, but the outputs are TensorShape, not PartialTensorShape
76 Status ValidateStridedSliceOp(
77     const Tensor* begin_tensor, const Tensor* end_tensor,
78     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
79     int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
80     int32_t new_axis_mask, int32_t shrink_axis_mask,
81     TensorShape* processing_shape, TensorShape* final_shape, bool* is_identity,
82     bool* is_simple_slice, bool* slice_dim0,
83     gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
84     gtl::InlinedVector<int64_t, 4>* strides,
85     StridedSliceShapeSpec* shape_spec = nullptr);
86 
87 // Simple class for determining if it is possible to broadcast a tensor to a
88 // strided slice.  Modelled after tensorflow::BCast, but with a few key
89 // differences:
90 // - the input_shape must be broadcastable to output_shape
91 //   (i.e. the slice shape does not grow).
92 // - does not allow reducing or flattening dimensions, since we cannot apply
93 //   these simplications to the destination slice.
94 // - allows for remapping dimensions, required in order to associate the input
95 //   with correct dimensions in the full (unsliced) destination tensor.
96 class StridedSliceAssignBCast {
97  public:
98   using Vec = gtl::InlinedVector<int64_t, 4>;
99 
100   StridedSliceAssignBCast(const Vec& input_shape, const Vec& output_shape);
101 
102   // Remaps broadcast, resize, and output dimensions via the provided map.
103   // Negative values in the map correspond to dimensions being removed.
104   // Unmapped dimensions are set to 1.
105   //
106   // This is to support remapping slice -> processing dimensions.  To relate
107   // the sliced output dimensions back to processing dimensions (i.e. those
108   // relative to the original unsliced input), we need to remove any axes
109   // that were added via the `new_axis_mask`, and add back any axes that were
110   // removed via the `shrink_axis_mask`.  For example, an expression like
111   //
112   // >>> t = tf.zeros([3, 3])
113   // >>> t[2, tf.newaxis, 0:2, tf.newaxis] = tf.ones([1, 3, 1])
114   //       ^                                          ^  ^  ^
115   //       |__ shrink axis                 new axis __|  |  |__ new axis
116   //                                                     |_____ dim 1 of t
117   //
118   // would have `new_axis_mask = 0b1010` and `shrink_axis_mask = 0b0001`. The
119   // slice has shape [1, 3, 1], but the original input tensor `t` has shape
120   // [3, 3]. To remap the slice dimensions back to the input dimensions, the
121   // mapping would use `num_dims = 2`, `dimension_map = {-1, 1, -1}`. This
122   // removes the two new axes added for the slice, maps the middle slice
123   // dimension to input dimension 1, and leaves input dimension 0 to have a
124   // default size of 1 to add back the shrink axis.
125   //
126   // Returns false if the remapping fails.
127   bool RemapDimensions(int64_t num_dims, const Vec& dimension_map);
128 
IsValid()129   bool IsValid() const { return valid_; }
130 
IsBroadcastingRequired()131   bool IsBroadcastingRequired() const { return broadcasting_required_; }
132 
reshape()133   const Vec& reshape() const { return reshape_; }
134 
bcast()135   const Vec& bcast() const { return bcast_; }
136 
result_shape()137   const Vec& result_shape() const { return result_shape_; }
138 
139  private:
140   bool valid_ = true;
141   bool broadcasting_required_ = false;
142   Vec reshape_;
143   Vec bcast_;
144   Vec result_shape_;
145 };
146 
147 }  // namespace tensorflow
148 
149 #endif  // TENSORFLOW_CORE_UTIL_STRIDED_SLICE_OP_H_
150