xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/tensor_slice.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
18 
19 #include <string>
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_slice.pb.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/core/stringpiece.h"
25 #include "tensorflow/core/lib/gtl/inlined_vector.h"
26 #include "tensorflow/core/platform/logging.h"
27 
28 namespace tensorflow {
29 
30 // A tensor slice represents a slice of a given tensor. It is represented by a
31 // list of (start, length) pairs, where the size of the list is the rank of the
32 // tensor.
33 
34 class TensorSlice {
35  public:
36   // Construct a tensor slice: you have a number of ways:
37   // -- creating an empty slice
38   // -- from just a dimension (in this case it will create a full slice)
39   // -- from an array of pairs of integers.
40   // -- from a TensorSliceProto protocol buffer
41   // -- from a string format of "start,length:start,length..." where each
42   //    "start,length" pair represents the slice on one dimension. We allow a
43   //    special "-" that means "everything for this dimension". One such example
44   //    is:  0,10:-:14,1:-:-
TensorSlice()45   TensorSlice() {}
46   explicit TensorSlice(int dim);
47   explicit TensorSlice(const TensorSliceProto& proto);
48   explicit TensorSlice(
49       std::initializer_list<std::pair<int64_t, int64_t>> extents);
50 
51   // This factory methods should be used instead of the constructor that takes a
52   // `TensorSliceProto` if calling code cannot validate that the sizes specify a
53   // valid `TensorSlice`.
54   static Status BuildTensorSlice(const TensorSliceProto& proto,
55                                  TensorSlice* output);
56 
57   static Status Parse(const string& str, TensorSlice* output);
ParseOrDie(const string & str)58   static TensorSlice ParseOrDie(const string& str) {
59     TensorSlice ret;
60     Status s = Parse(str, &ret);
61     if (!s.ok()) {
62       LOG(FATAL) << "Could not parse TensorSlice";
63     }
64     return ret;
65   }
66 
67   void Clear();
68 
69   // Accessors
dims()70   int dims() const { return starts_.size(); }
71 
start(int d)72   int64_t start(int d) const {
73     DCHECK_GE(d, 0);
74     DCHECK_LT(d, dims());
75     return starts_[d];
76   }
77 
length(int d)78   int64_t length(int d) const {
79     DCHECK_GE(d, 0);
80     DCHECK_LT(d, dims());
81     return lengths_[d];
82   }
83 
end(int d)84   int64_t end(int d) const {
85     DCHECK_GE(d, 0);
86     DCHECK_LT(d, dims());
87     return start(d) + length(d);
88   }
89 
set_start(int d,int64_t x)90   void set_start(int d, int64_t x) {
91     DCHECK_GE(d, 0);
92     DCHECK_LT(d, dims());
93     DCHECK_GE(x, 0);
94     starts_[d] = x;
95   }
96 
set_length(int d,int64_t x)97   void set_length(int d, int64_t x) {
98     DCHECK_GE(d, 0);
99     DCHECK_LT(d, dims());
100     lengths_[d] = x;
101   }
102 
103   // If we have a full slice along dimension "d".
IsFullAt(int d)104   bool IsFullAt(int d) const {
105     return lengths_[d] == kFullExtent && starts_[d] == 0;
106   }
107 
108   // If this is a full slice, i.e. IsFullAt(d) for every d.
109   bool IsFull() const;
110 
111   // Set the slice to be a full slice of "dim" dimensions
112   void SetFullSlice(int dim);
113 
114   // Extend a slice to "dim" dimensions: all the added dimensions are full.
115   // Requires: dim >= dims().
116   void Extend(int dim);
117 
118   // Conversion of a TensorSlice to other formats
119   void AsProto(TensorSliceProto* proto) const;
120   string DebugString() const;
121 
122   // Fill *indices and *sizes from *this (so that we can use the slice()
123   // function in eigen tensor). We need a tensor shape in case some of the
124   // slices are full slices.
125   // We allow NDIMS to be greater than dims(), in which case we will pad the
126   // higher dimensions with trivial dimensions.
127   template <int NDIMS>
128   void FillIndicesAndSizes(
129       const TensorShape& shape,
130       Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
131       Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const;
132 
133   // Interaction with other TensorSlices.
134 
135   // Compute the intersection with another slice and if "result" is not
136   // nullptr, store the results in *result; returns true if there is any real
137   // intersection.
138   bool Intersect(const TensorSlice& other, TensorSlice* result) const;
139   // A short hand.
Overlaps(const TensorSlice & other)140   bool Overlaps(const TensorSlice& other) const {
141     return Intersect(other, nullptr);
142   }
143 
144   // Equals iff "*this" and "other" are logically equivalent.
145   bool operator==(const TensorSlice& other) const;
146   bool operator!=(const TensorSlice& other) const { return !(*this == other); }
147 
148   // Interaction with TensorShape.
149 
150   // Slices a shape and stores the result into *result_shape.
151   // Requires that the shape and *this have the same rank.
152   // For example, given a tensor shape of {3, 4, 5}, and a slice of
153   // 1,2:-:0,2, the result shape is {2, 4, 2}.
154   Status SliceTensorShape(const TensorShape& shape,
155                           TensorShape* result_shape) const;
156 
157   // Given slice "sub" where "sub" is fully contained in *this,
158   // (meaning that the intersection of "sub" and *this equals "sub"), computes
159   // the "relative" slice of "sub" with respect to *this.
160   //
161   // In other words, if we use A>S to denote slicing a shape S with a slice A,
162   // then the function is computing a slice X such that:
163   //   X > (this > S) = sub > S
164   // for any shape S.
165   //
166   // In general, along every dimension, the start of the relative slice is the
167   // start of the "sub" slice minus the start of *this; the length of the
168   // relative slice is the length of the "sub" slice.
169   //
170   // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and
171   // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2.
172   //
173   // The caller needs to make sure that "sub" is indeed a sub-slice of *this;
174   // otherwise the result is undefined.
175   void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const;
176 
177   // Updates the slice in such a way that it fully covers "other" slice.
178   // Note, "other" slice should refer to the same tensor shape.
179   // Example:
180   //   given a slice [2:4, :, 3:] and "other" slice [:, 1:4, 2:4] the
181   //   updated slice would be [:, :, 2:]. Here is why:
182   //   dim 0: "2:4"  U  ":"    ->  ":"
183   //   dim 1: ":"    U  "1-4"  ->  ":"
184   //   dim 2: "3:"   U  "2:4"  ->  "2:"
185   void UpdateToCover(const TensorSlice& other);
186 
187   // Returns true if the length field was specified in an Extent.
188   static bool HasExtentLength(const TensorSliceProto::Extent& extent);
189 
190   // Returns the value of the length field in an Extent, or -1 if it
191   // is not present.
192   static int64_t GetExtentLength(const TensorSliceProto::Extent& extent);
193 
194  private:
195   // a length value of kFullExtent (-1) means we have a full slice at this
196   // dimension. It's defined in tensor_slice.cc.
197   static const int64_t kFullExtent;
198 
199   // TODO(yangke): switch to Eigen once it supports variable size arrays.
200   // A value of
201   gtl::InlinedVector<int64_t, 4> starts_;
202   gtl::InlinedVector<int64_t, 4> lengths_;
203 };
204 
205 template <int NDIMS>
FillIndicesAndSizes(const TensorShape & shape,Eigen::DSizes<Eigen::DenseIndex,NDIMS> * indices,Eigen::DSizes<Eigen::DenseIndex,NDIMS> * sizes)206 void TensorSlice::FillIndicesAndSizes(
207     const TensorShape& shape, Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
208     Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const {
209   CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
210                                  << "slices: shape = " << shape.DebugString()
211                                  << ", slice = " << DebugString();
212   CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from "
213                           << "a slice of dimension " << dims();
214   for (int d = 0; d < dims(); ++d) {
215     if (IsFullAt(d)) {
216       (*indices)[d] = 0;
217       (*sizes)[d] = shape.dim_size(d);
218     } else {
219       (*indices)[d] = starts_[d];
220       (*sizes)[d] = lengths_[d];
221     }
222   }
223   for (int d = dims(); d < NDIMS; ++d) {
224     (*indices)[d] = 0;
225     (*sizes)[d] = 1;
226   }
227 }
228 
229 }  // namespace tensorflow
230 
231 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
232