xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/sparse/sparse_tensor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
17 #define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
18 
19 #include <limits>
20 #include <numeric>
21 #include <vector>
22 
23 #include "absl/base/macros.h"
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/sparse/dim_comparator.h"
37 #include "tensorflow/core/util/sparse/group_iterator.h"
38 
39 namespace tensorflow {
40 namespace sparse {
41 
42 class SparseTensor {
43  public:
44   typedef typename gtl::ArraySlice<int64_t> VarDimArray;
45   typedef typename gtl::InlinedVector<int64_t, 8> ShapeArray;
46 
47   static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
48                        const VarDimArray order, SparseTensor* result);
49 
50   static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
51                        SparseTensor* result);
52 
53   static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
54                        SparseTensor* result);
55 
56   static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
57                        const VarDimArray order, SparseTensor* result);
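  //
  // Example usage (a minimal sketch; the COO indices, values and shape below
  // are hypothetical):
  //
  //   Tensor ix(DT_INT64, TensorShape({2, 2}));  // two entries, rank 2
  //   ix.matrix<int64_t>()(0, 0) = 0; ix.matrix<int64_t>()(0, 1) = 1;
  //   ix.matrix<int64_t>()(1, 0) = 2; ix.matrix<int64_t>()(1, 1) = 3;
  //   Tensor vals(DT_FLOAT, TensorShape({2}));
  //   vals.vec<float>()(0) = 1.0f; vals.vec<float>()(1) = 2.0f;
  //   SparseTensor st;
  //   Status s = SparseTensor::Create(ix, vals, TensorShape({4, 4}), &st);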
58 
59   SparseTensor() : dims_(0) {}
60 
61   ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
62   SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
63       : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
64                      UndefinedOrder(TensorShapeToVector(shape))) {}
65 
66   ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
67   SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
68       : SparseTensor(std::move(ix), std::move(vals), shape,
69                      UndefinedOrder(shape)) {}
70 
71   ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
72   SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
73                const VarDimArray order)
74       : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
75                      order) {}
76 
77   ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
78   SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
79                const VarDimArray order);
80 
81   SparseTensor(const SparseTensor& other)
82       : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}
83 
84   SparseTensor(SparseTensor&& other)
85       : SparseTensor(std::move(other.ix_), std::move(other.vals_),
86                      std::move(other.shape_), std::move(other.order_)) {}
87 
88   SparseTensor& operator=(const SparseTensor& other) {
89     ix_ = other.ix_;
90     vals_ = other.vals_;
91     shape_ = other.shape_;
92     order_ = other.order_;
93     dims_ = other.dims_;
94     return *this;
95   }
96 
97   SparseTensor& operator=(SparseTensor&& other) {
98     ix_ = std::move(other.ix_);
99     vals_ = std::move(other.vals_);
100     shape_ = std::move(other.shape_);
101     order_ = std::move(other.order_);
102     dims_ = std::move(other.dims_);
103     return *this;
104   }
105 
106   std::size_t num_entries() const { return ix_.dim_size(0); }
107 
108   int dims() const { return shape_.size(); }
109 
110   const Tensor& indices() const { return ix_; }
111 
112   const Tensor& values() const { return vals_; }
113 
114   DataType dtype() const { return vals_.dtype(); }
115 
116   Status IndicesValid() const;
117 
118   VarDimArray shape() const { return shape_; }
119 
120   VarDimArray order() const { return order_; }
121 
122   // Resorts the indices and values according to the dimensions in order.
123   template <typename T>
124   void Reorder(const VarDimArray& order);
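  //
  // Example (sketch): sort a rank-2 float SparseTensor `st` into row-major
  // (standard) order before grouping or concatenating it:
  //
  //   st.Reorder<float>({0, 1});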
125 
126   // Returns a group iterable that can be used for clumping indices
127   // and values according to the group indices of interest.
128   //
129   // Precondition: order()[0..group_ix.size()] == group_ix.
130   //
131   // See the README.md in this directory for more usage information.
132   GroupIterable group(const VarDimArray& group_ix) const {
133     DCHECK_LE(group_ix.size(), dims_);
134     for (std::size_t di = 0; di < group_ix.size(); ++di) {
135       DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
136       DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
137       DCHECK_EQ(group_ix[di], order_[di])
138           << "Group dimension does not match sorted order";
139     }
140     return GroupIterable(ix_, vals_, dims_, group_ix);
141   }
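
  // Example (sketch, assuming `st` holds floats and has already been
  // sorted with Reorder<float>({0, 1})): iterate over the entries that
  // share each value of dimension 0:
  //
  //   for (const auto& g : st.group({0})) {
  //     // g.group()[0] is the shared dim-0 index of this group.
  //     auto group_vals = g.values<float>();  // values in this group
  //     auto group_ix = g.indices();          // their full indices
  //   }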
142 
143   // Scatters the sparse values into the dense tensor `out` at the
144   // locations given by the sparse indices.  Preconditions:
145   //   out->shape().dims() == dims()
146   //   out->shape().dim_size(d) >= shape()[d] for all d
147   //
148   // Returns true on success.  False on failure (mismatched dimensions
149   // or out-of-bounds indices).
150   //
151   // If initialize == true, ToDense first sets all coefficients of out to 0.
152   //
153   template <typename T>
154   bool ToDense(Tensor* out, bool initialize = true);
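  //
  // Example (sketch): scatter a float SparseTensor `st` of shape {4, 4} into
  // a dense tensor; with the default initialize == true the output is
  // zero-filled first:
  //
  //   Tensor dense(DT_FLOAT, TensorShape({4, 4}));
  //   bool ok = st.ToDense<float>(&dense);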
155 
156   // Concat() concatenates all the tensors along their primary (first-order)
157   // dimension.  All tensors must have identical shapes except along that
158   // dimension, and the first entry of every tensor's order() must match
159   // (this is the concat dimension).
160   //
161   // If all of the tensors have identical ordering, then the output
162   // will have this ordering.  Otherwise the output is set as not
163   // having any order and a Reorder<T>() should be called on it before
164   // performing any subsequent operations.
165   template <typename T>
166   static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
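  //
  // Example (sketch): concatenate two float SparseTensors along dimension 0.
  // Both are assumed to be sorted so that order()[0] == 0, e.g. via
  // Reorder<float>({0, 1}):
  //
  //   std::vector<SparseTensor> parts = {st_a, st_b};
  //   SparseTensor combined = SparseTensor::Concat<float>(parts);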
167 
168   // Split() splits the input SparseTensor into a list of num_split
169   // SparseTensors along dimension split_dim. If the size of split_dim is not
170   // an integer multiple of num_split, the first (size % num_split) slices
171   // are one element larger along split_dim.
172   template <typename T>
173   static Status Split(const SparseTensor& tensor, const int split_dim,
174                       const int num_split, std::vector<SparseTensor>* result);
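  //
  // Example (sketch): split a float SparseTensor `st` whose dimension 0 has
  // size 7 into 3 pieces; the resulting sizes along dimension 0 are 3, 2, 2:
  //
  //   std::vector<SparseTensor> pieces;
  //   Status s = SparseTensor::Split<float>(st, /*split_dim=*/0,
  //                                         /*num_split=*/3, &pieces);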
175 
176   // Slice() slices the input SparseTensor into a SparseTensor based on the
177   // specified start and size. Both start and size are 1-D arrays with one
178   // element per dimension: start gives the starting index and size gives the
179   // extent along each dimension.
180   template <typename T>
181   static StatusOr<SparseTensor> Slice(const SparseTensor& tensor,
182                                       const gtl::ArraySlice<int64_t> start,
183                                       const gtl::ArraySlice<int64_t> size);
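  //
  // Example (sketch): take rows [1, 3) and columns [0, 3) of a rank-2 float
  // SparseTensor `st`:
  //
  //   StatusOr<SparseTensor> sliced =
  //       SparseTensor::Slice<float>(st, /*start=*/{1, 0}, /*size=*/{2, 3});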
184 
185   // Picks out the dimensions according to `dim_indices`.
186   std::vector<int64_t> PickDims(gtl::ArraySlice<int64_t> dim_indices) const {
187     std::vector<int64_t> res(dim_indices.size());
188     for (size_t i = 0; i < dim_indices.size(); ++i) {
189       res[i] = shape_[dim_indices[i]];
190     }
191     return res;
192   }
193 
194  private:
195   static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
196     return ShapeArray(shape.size(), -1);
197   }
198 
199   static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
200     ShapeArray vec(shape.dims());
201     for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
202     return vec;
203   }
204 
205   // Optimized implementation of `IndicesValid` for 1-D sparse tensors.
206   // REQUIRES: `shape_.size() == 1`.
207   bool IndicesValidVectorFastPath() const;
208 
209   // Optimized implementation of `IndicesValid` for 2-D sparse tensors whose
210   // indices fit within the range of an `int32`.
211   // REQUIRES: `shape_.size() == 2`.
212   bool IndicesValidMatrix32BitFastPath() const;
213 
214   template <bool standard_order>
215   Status IndicesValidHelper() const;
216 
217   // Helper for ToDense<T>()
218   template <typename T>
219   bool ValidateAndInitializeToDense(Tensor* out, bool initialize);
220 
221   // Helper for Split() that returns the slice index.
222   static inline int GetSliceIndex(const int dim, const int split_size,
223                                   const int residual) {
224     DCHECK_GT(split_size, 0);
225     DCHECK_GE(dim, 0);
226     if (residual == 0) return dim / split_size;
227     const int offset = residual * (split_size + 1);
228     if (dim < offset) {
229       return dim / (split_size + 1);
230     } else {
231       return residual + ((dim - offset) / split_size);
232     }
233   }
234 
235   // Helper for Split() that returns the dimension in the slice.
236   static inline int GetDimensionInSlice(const int dim, const int split_size,
237                                         const int residual) {
238     DCHECK_GT(split_size, 0);
239     DCHECK_GE(dim, 0);
240     if (residual == 0) return dim % split_size;
241     const int offset = residual * (split_size + 1);
242     if (dim < offset) {
243       return dim % (split_size + 1);
244     } else {
245       return (dim - offset) % split_size;
246     }
247   }
248 
249   // Helper for Split() that returns the shape given a slice index.
250   static inline int GetSliceShape(const int slice_index, const int split_size,
251                                   const int residual) {
252     DCHECK_GT(split_size, 0);
253     DCHECK_GE(slice_index, 0);
254     if (residual == 0) return split_size;
255     if (slice_index < residual) {
256       return split_size + 1;
257     } else {
258       return split_size;
259     }
260   }
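
  // Worked example for the three Split() helpers above (sketch): splitting a
  // dimension of size 7 into num_split = 3 gives split_size = 7 / 3 = 2 and
  // residual = 7 % 3 = 1, so the slice sizes are {3, 2, 2}.  For an index
  // dim = 4:
  //   GetSliceIndex(4, 2, 1)       == 1  // it lands in the second slice
  //   GetDimensionInSlice(4, 2, 1) == 1  // at position 1 within that slice
  //   GetSliceShape(1, 2, 1)       == 2  // and that slice has size 2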
261 
262   Tensor ix_;
263   Tensor vals_;
264   ShapeArray shape_;
265   ShapeArray order_;
266   int dims_;
267 };
268 
269 // This operation updates the indices and values Tensor rows, so it is
270 // an in-place algorithm.  It requires O(N log N) time and O(N)
271 // temporary space.
272 template <typename T>
273 inline void SparseTensor::Reorder(const VarDimArray& order) {
274   DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
275       << "Reorder requested with the wrong datatype";
276   DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
277   auto ix_t = ix_.matrix<int64_t>();
278   auto vals_t = vals_.vec<T>();
279 
280   std::vector<int64_t> reorder(num_entries());
281   std::iota(reorder.begin(), reorder.end(), 0);
282 
283   // Sort to get order of indices
284   switch (order.size()) {
285 #define CASE_SORT(ORDER_SIZE)                                    \
286   case ORDER_SIZE: {                                             \
287     FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \
288     std::sort(reorder.begin(), reorder.end(), sorter);           \
289     break;                                                       \
290   }
291     CASE_SORT(0);
292     CASE_SORT(1);
293     CASE_SORT(2);
294     CASE_SORT(3);
295     CASE_SORT(4);
296     CASE_SORT(5);
297 #undef CASE_SORT
298     default: {
299       DimComparator sorter(ix_t, order, shape());
300       std::sort(reorder.begin(), reorder.end(), sorter);
301     }
302   }
303 
304   // We have a forward reordering, but what we'll need is a
305   // permutation (the inverse).  This can be calculated with O(1)
306   // additional space and O(n) time (INVPERM), but we just do the
307   // simple thing here.
308   std::vector<size_t> permutation(reorder.size());
309   for (std::size_t n = 0; n < reorder.size(); ++n) {
310     permutation[reorder[n]] = n;
311   }
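  // Worked example (sketch): if sorting yields reorder = {2, 0, 1} (the
  // sorted order is original rows 2, 0, 1), the loop above produces
  // permutation = {1, 2, 0}: original row n must end up at position
  // permutation[n].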
312 
313   // Update indices & values by converting the permutations to
314   // a product of transpositions.  Iterate over the cycles in the
315   // permutation, and convert each of those into a product of
316   // transpositions (swaps):
317   //   https://en.wikipedia.org/wiki/Cyclic_permutation
318   // This is N swaps, 2*N comparisons.
319   for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
320     while (n != permutation[n]) {
321       std::size_t r = permutation[n];
322       std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
323       std::swap(vals_t(n), vals_t(r));
324       std::swap(permutation[n], permutation[r]);
325     }
326   }
327 
328   order_ = ShapeArray(order.begin(), order.end());
329 }
330 
331 template <typename T>
332 inline bool SparseTensor::ValidateAndInitializeToDense(Tensor* out,
333                                                        bool initialize) {
334   DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
335       << "ToDense requested with the wrong datatype";
336 
337   DCHECK_EQ(out->shape().dims(), dims_)
338       << "Incompatible dimensions between SparseTensor and output";
339 
340   DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
341       << "Output must be type: " << DataTypeToEnum<T>::v()
342       << " but got: " << out->dtype();
343 
344   // Make sure the dense output is the same rank and has room
345   // to hold the SparseTensor.
346   const auto& out_shape = out->shape();
347   if (shape_.size() != out_shape.dims()) return false;
348   for (int d = 0; d < shape_.size(); ++d) {
349     if (shape_[d] > out_shape.dim_size(d)) return false;
350   }
351 
352   if (initialize) {
353     auto out_t = out->flat<T>();
354     out_t.setConstant(T());
355   }
356 
357   return true;
358 }
359 
360 template <typename T>
361 inline bool SparseTensor::ToDense(Tensor* out, bool initialize) {
362   if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;
363 
364   auto out_t = out->flat<T>();
365   auto vals_t = vals_.vec<T>();
366   auto ix_t = ix_.matrix<int64_t>();
367   const int64_t* const ix_ptr = ix_t.data();
368 
369   if (dims_ == 1) {
370     // Fast path for sparse vectors.
371     const int64_t out_length = out->shape().dim_size(0);
372     for (int n = 0; n < vals_t.dimension(0); ++n) {
373       const int64_t index = internal::SubtleMustCopy(ix_ptr[n]);
374       if (!FastBoundsCheck(index, out_length)) return false;
375       out_t(index) = vals_t(n);
376     }
377     return true;
378   } else if (dims_ == 2) {
379     // Fast path for sparse matrices.
380     const auto& out_shape = out->shape();
381     const int64_t out_rows = out_shape.dim_size(0);
382     const int64_t out_cols = out_shape.dim_size(1);
383     for (int n = 0; n < vals_t.dimension(0); ++n) {
384       const int64_t row_index = internal::SubtleMustCopy(ix_ptr[n * 2]);
385       const int64_t col_index = internal::SubtleMustCopy(ix_ptr[n * 2 + 1]);
386       if (!(FastBoundsCheck(row_index, out_rows) &&
387             FastBoundsCheck(col_index, out_cols))) {
388         return false;
389       }
390       out_t(row_index * out_cols + col_index) = vals_t(n);
391     }
392     return true;
393   } else {
394     // General path for N-dimensional sparse tensors.
395     gtl::InlinedVector<int64_t, 4> strides(dims_);
396     const auto& out_shape = out->shape().dim_sizes();
397     if (dims_ > 0) {
398       strides[dims_ - 1] = 1;
399     }
400     for (int d = dims_ - 2; d >= 0; --d) {
401       strides[d] = strides[d + 1] * out_shape[d + 1];
402     }
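    // For example (sketch), a dense output of shape {2, 3, 4} yields
    // strides {12, 4, 1}, so index (i, j, k) maps to flat offset
    // i * 12 + j * 4 + k.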
403 
404     for (int n = 0; n < vals_t.dimension(0); ++n) {
405       bool invalid_dims = false;
406       int64_t ix = 0;
407       for (int d = 0; d < dims_; ++d) {
408         const int64_t ix_n_d = internal::SubtleMustCopy(ix_ptr[n * dims_ + d]);
409         if (!FastBoundsCheck(ix_n_d, out_shape[d])) {
410           invalid_dims = true;
411         }
412         ix += strides[d] * ix_n_d;
413       }
414       if (invalid_dims) return false;
415       out_t(ix) = vals_t(n);
416     }
417     return true;
418   }
419 }
420 
421 template <typename T>
422 inline SparseTensor SparseTensor::Concat(
423     const gtl::ArraySlice<SparseTensor>& tensors) {
424   DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
425   const int dims = tensors[0].dims_;
426   DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
427   auto order_0 = tensors[0].order();
428   const int primary_dim = order_0[0];
429   ShapeArray final_order(order_0.begin(), order_0.end());
430   ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
431   final_shape[primary_dim] = 0;  // We'll build this up as we go along.
432   int num_entries = 0;
433 
434   bool fully_ordered = true;
435   for (const SparseTensor& st : tensors) {
436     DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
437     DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
438         << "Concat requested with the wrong data type";
439     DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
440     DCHECK_EQ(st.order()[0], primary_dim)
441         << "All SparseTensors' order[0] must match.  This is the concat dim.";
442     if (st.order() != final_order) fully_ordered = false;
443     const VarDimArray& st_shape = st.shape();
444     for (int d = 0; d < dims - 1; ++d) {
445       const int cdim = (d < primary_dim) ? d : d + 1;
446       DCHECK_EQ(final_shape[cdim], st_shape[cdim])
447           << "All SparseTensors' shapes must match except on the concat dim.  "
448           << "Concat dim: " << primary_dim
449           << ", mismatched shape at dim: " << cdim
450           << ".  Expecting shape like: [" << str_util::Join(final_shape, ",")
451           << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
452     }
453 
454     // Update dimension of final shape
455     final_shape[primary_dim] =
456         (final_shape[primary_dim] + st_shape[primary_dim]);
457 
458     num_entries += st.num_entries();  // Update number of entries
459   }
460 
461   // If the ordering is inconsistent among inputs, set the final order to -1s.
462   if (!fully_ordered) {
463     final_order = UndefinedOrder(final_shape);
464   }
465 
466   Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
467   Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));
468 
469   TTypes<int64_t>::Matrix ix_t = output_ix.matrix<int64_t>();
470   typename TTypes<T>::Vec vals_t = output_vals.vec<T>();
471 
472   Eigen::DenseIndex offset = 0;
473   int64_t shape_offset = 0;
474   for (const SparseTensor& st : tensors) {
475     const int st_num_entries = st.num_entries();
476 
477     // Fill in indices & values.
478     if (st_num_entries > 0) {
479       std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
480 
481       const auto* st_ix = &st.ix_.matrix<int64_t>()(0, 0);
482       auto* ix_out = &ix_t(offset, 0);
483       for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
484         *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
485       }
486     }
487 
488     offset += st_num_entries;
489     shape_offset += st.shape()[primary_dim];
490   }
491 
492   return SparseTensor(output_ix, output_vals, final_shape, final_order);
493 }
494 
495 template <typename T>
496 inline Status SparseTensor::Split(const SparseTensor& input_tensor,
497                                   const int split_dim, const int num_split,
498                                   std::vector<SparseTensor>* result) {
499   std::vector<Tensor> output_indices;
500   std::vector<Tensor> output_values;
501   std::vector<TensorShape> output_shapes;
502   output_indices.reserve(num_split);
503   output_values.reserve(num_split);
504   output_shapes.reserve(num_split);
505 
506   std::vector<typename TTypes<int64_t>::Matrix> output_indices_t;
507   std::vector<typename TTypes<T>::Vec> output_values_t;
508   output_indices_t.reserve(num_split);
509   output_values_t.reserve(num_split);
510   auto input_values_t = input_tensor.values().vec<T>();
511   auto input_indices_t = input_tensor.indices().matrix<int64_t>();
512 
513   std::vector<int> num_values(num_split, 0);
514   const int num_dim = input_tensor.shape().size();
515   const int split_dim_size = input_tensor.shape()[split_dim];
516   const int split_size = split_dim_size / num_split;
517 
518   if (!(num_split > 0 && num_split <= split_dim_size)) {
519     return errors::InvalidArgument("num_split must be in the interval (0, ",
520                                    split_dim_size, "]");
521   }
522   if (!(split_dim >= 0 && split_dim < num_dim)) {
523     return errors::InvalidArgument("num_dim must be in the interval [0, ",
524                                    num_dim, ")");
525   }
526 
527   const int residual = split_dim_size % num_split;
528   for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
529     const int dim = input_tensor.indices().matrix<int64_t>()(i, split_dim);
530     int slice_index = GetSliceIndex(dim, split_size, residual);
531     if (slice_index >= num_values.size()) {
532       return errors::InvalidArgument("Slice index ", slice_index,
533                                      " is larger than num_split.");
534     }
535     num_values[slice_index]++;
536   }
537 
538   for (int i = 0; i < num_split; ++i) {
539     // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
540     output_indices.emplace_back(DT_INT64,
541                                 TensorShape({num_values[i], num_dim}));
542     output_values.emplace_back(DataTypeToEnum<T>::v(),
543                                TensorShape({num_values[i]}));
544     output_shapes.emplace_back(input_tensor.shape());
545     output_indices_t.emplace_back(output_indices[i].matrix<int64_t>());
546     output_values_t.emplace_back(output_values[i].vec<T>());
547     const int size = GetSliceShape(i, split_size, residual);
548     output_shapes[i].set_dim(split_dim, size);
549   }
550 
551   std::vector<int> values_inserted_in_slice(num_split, 0);
552   for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
553     const int dim = input_indices_t(i, split_dim);
554     const int slice_index = GetSliceIndex(dim, split_size, residual);
555     const int slice_dim = values_inserted_in_slice[slice_index]++;
556     output_values_t[slice_index](slice_dim) = input_values_t(i);
557     for (int j = 0; j < num_dim; ++j) {
558       const int64_t original_dim = input_indices_t(i, j);
559       output_indices_t[slice_index](slice_dim, j) =
560           (j == split_dim)
561               ? GetDimensionInSlice(original_dim, split_size, residual)
562               : original_dim;
563     }
564   }
565 
566   result->clear();
567   result->reserve(num_split);
568   for (int i = 0; i < num_split; ++i) {
569     SparseTensor tensor;
570     Status create_status =
571         Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
572     if (!create_status.ok()) {
573       return create_status;
574     }
575     result->push_back(std::move(tensor));
576   }
577   return OkStatus();
578 }
579 
580 template <typename T>
581 inline StatusOr<SparseTensor> SparseTensor::Slice(
582     const SparseTensor& input_tensor, const gtl::ArraySlice<int64_t> start,
583     const gtl::ArraySlice<int64_t> size) {
584   TensorShape output_shape(input_tensor.shape());
585 
586   const int dims = input_tensor.dims();
587   for (int dim = 0; dim < dims; dim++) {
588     // Determine the size of the result; if the selected slice goes beyond the
589     // input boundary, the result will correspond to the size of the overlap
590     // between the input and the selected slice.
591     const int64_t input_size = output_shape.dim_size(dim);
592     const int64_t start_index = start[dim];
593     const int64_t slice_size = size[dim];
594 
595     if (start_index < input_size - slice_size) {
596       // The entire selection is within input boundaries.
597       TF_RETURN_IF_ERROR(output_shape.SetDimWithStatus(dim, slice_size));
598     } else if (start_index < input_size) {
599       // The selection starts within input boundaries, but goes beyond them.
600       TF_RETURN_IF_ERROR(
601           output_shape.SetDimWithStatus(dim, input_size - start_index));
602     } else {
603       // The selection is entirely out of input boundaries.
604       TF_RETURN_IF_ERROR(output_shape.SetDimWithStatus(dim, 0));
605     }
606   }
607 
608   auto input_indices_t = input_tensor.indices().matrix<int64_t>();
609   auto input_values_t = input_tensor.values().vec<T>();
610 
611   // Find the number of indices that fall inside start and size.
612   int count = 0;
613   for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
614     // The following checks whether this entry lies within the range
615     // specified by start and size.
616     // The for loop below iterates through all dimensions. If the index
617     // falls outside of start and size at any dimension, the entry is
618     // considered a "no hit" (hit = false). In that case, it is not
619     // counted among the indices that fall inside the range specified
620     // by start and size.
621     bool hit = true;
622     for (int dim = 0; dim < dims; dim++) {
623       if (!(start[dim] <= input_indices_t(i, dim) &&
624             input_indices_t(i, dim) < start[dim] + size[dim])) {
625         hit = false;
626         break;
627       }
628     }
629     if (!hit) {
630       continue;
631     }
632     count++;
633   }
634 
635   Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
636   Tensor output_indices(DT_INT64, TensorShape({count, dims}));
637 
638   auto output_values_t = output_values.vec<T>();
639   auto output_indices_t = output_indices.matrix<int64_t>();
640 
641   // Obtain the output indices that fall inside start and size.
642   int index = 0;
643   for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
644        i++) {
645     // The logic here is similar to the above, except that the above
646     // only counts the number of indices while here we actually generate
647     // the output.
648     bool hit = true;
649     for (int dim = 0; dim < dims; dim++) {
650       if (!(start[dim] <= input_indices_t(i, dim) &&
651             input_indices_t(i, dim) < start[dim] + size[dim])) {
652         hit = false;
653         break;
654       }
655     }
656     if (!hit) {
657       continue;
658     }
659     output_values_t(index) = input_values_t(i);
660     for (int dim = 0; dim < dims; dim++) {
661       output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
662     }
663     index++;
664   }
665 
666   return SparseTensor(output_indices, output_values, output_shape);
667 }
668 
669 }  // namespace sparse
670 }  // namespace tensorflow
671 
672 #endif  // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
673