/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_

#include <limits>
#include <numeric>
#include <vector>

#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
#include "tensorflow/core/util/sparse/group_iterator.h"

namespace tensorflow {
namespace sparse {

class SparseTensor {
 public:
  typedef typename gtl::ArraySlice<int64_t> VarDimArray;
  typedef typename gtl::InlinedVector<int64_t, 8> ShapeArray;

  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       const VarDimArray order, SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       const VarDimArray order, SparseTensor* result);
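
  // Illustrative sketch (the tensor contents below are assumptions of this
  // example, not part of the API): building a 2x3 float SparseTensor with two
  // nonzero entries. `ix` holds one [row, col] pair per entry and `vals` holds
  // the matching values.
  //
  //   Tensor ix(DT_INT64, TensorShape({2, 2}));
  //   auto ix_m = ix.matrix<int64_t>();
  //   ix_m(0, 0) = 0; ix_m(0, 1) = 1;   // entry 0 at [0, 1]
  //   ix_m(1, 0) = 1; ix_m(1, 1) = 2;   // entry 1 at [1, 2]
  //   Tensor vals(DT_FLOAT, TensorShape({2}));
  //   vals.vec<float>()(0) = 3.0f;
  //   vals.vec<float>()(1) = 4.0f;
  //   SparseTensor st;
  //   Status s = SparseTensor::Create(ix, vals, TensorShape({2, 3}), &st);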

  SparseTensor() : dims_(0) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
      : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                     UndefinedOrder(TensorShapeToVector(shape))) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
      : SparseTensor(std::move(ix), std::move(vals), shape,
                     UndefinedOrder(shape)) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
               const VarDimArray order)
      : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                     order) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
               const VarDimArray order);

  SparseTensor(const SparseTensor& other)
      : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}

  SparseTensor(SparseTensor&& other)
      : SparseTensor(std::move(other.ix_), std::move(other.vals_),
                     std::move(other.shape_), std::move(other.order_)) {}

  SparseTensor& operator=(const SparseTensor& other) {
    ix_ = other.ix_;
    vals_ = other.vals_;
    shape_ = other.shape_;
    order_ = other.order_;
    dims_ = other.dims_;
    return *this;
  }

  SparseTensor& operator=(SparseTensor&& other) {
    ix_ = std::move(other.ix_);
    vals_ = std::move(other.vals_);
    shape_ = std::move(other.shape_);
    order_ = std::move(other.order_);
    dims_ = std::move(other.dims_);
    return *this;
  }

  std::size_t num_entries() const { return ix_.dim_size(0); }

  int dims() const { return shape_.size(); }

  const Tensor& indices() const { return ix_; }

  const Tensor& values() const { return vals_; }

  DataType dtype() const { return vals_.dtype(); }

  Status IndicesValid() const;

  VarDimArray shape() const { return shape_; }

  VarDimArray order() const { return order_; }

  // Sorts the indices and values, in place, according to the dimensions in
  // `order`.
  template <typename T>
  void Reorder(const VarDimArray& order);
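
  // Illustrative sketch (assumes a 2-D float SparseTensor `st`, e.g. the one
  // built in the Create() example above): sort into standard row-major order
  // before grouping or converting to dense.
  //
  //   st.Reorder<float>({0, 1});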

  // Returns a group iterable that can be used for clumping indices
  // and values according to the group indices of interest.
  //
  // Precondition: order()[0..group_ix.size()] == group_ix.
  //
  // See the README.md in this directory for more usage information.
  GroupIterable group(const VarDimArray& group_ix) const {
    DCHECK_LE(group_ix.size(), dims_);
    for (std::size_t di = 0; di < group_ix.size(); ++di) {
      DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
      DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
      DCHECK_EQ(group_ix[di], order_[di])
          << "Group dimension does not match sorted order";
    }
    return GroupIterable(ix_, vals_, dims_, group_ix);
  }
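
  // Illustrative sketch (assumes a 2-D float SparseTensor `st` already
  // reordered into standard order {0, 1}; the Group accessors used here are
  // declared in group_iterator.h):
  //
  //   for (const auto& g : st.group({0})) {
  //     // Every entry in `g` shares the same index along dimension 0.
  //     auto group_vals = g.values<float>();  // values in this group
  //     auto group_ix = g.indices();          // matching indices
  //   }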

  // Stores the sparse indices into the dense tensor out.
  // Preconditions:
  //   out->shape().dims() == shape().dims()
  //   out->shape().dim_size(d) >= shape()[d] for all d
  //
  // Returns true on success. False on failure (mismatched dimensions
  // or out-of-bounds indices).
  //
  // If initialize == true, ToDense first overwrites all coefficients in out
  // with 0.
  //
  template <typename T>
  bool ToDense(Tensor* out, bool initialize = true);
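
  // Illustrative sketch (assumes a float SparseTensor `st` with shape {2, 3};
  // the `dense` tensor is an assumption of this example):
  //
  //   Tensor dense(DT_FLOAT, TensorShape({2, 3}));
  //   if (!st.ToDense<float>(&dense)) {
  //     // Mismatched dimensions or out-of-bounds indices.
  //   }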

  // Concat() concatenates the input tensors along the dimension given by
  // their first order index, order()[0]. All tensors must have identical
  // shapes except along that dimension, and all must have the same
  // order()[0] (the concat dimension).
  //
  // If all of the tensors have identical ordering, then the output will have
  // this ordering. Otherwise the output is marked as having undefined order,
  // and Reorder<T>() should be called on it before performing any subsequent
  // operations.
  template <typename T>
  static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
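
  // Illustrative sketch (assumes two 2-D float SparseTensors `a` and `b`,
  // both ordered with order()[0] == 0 and with identical shapes except along
  // dimension 0):
  //
  //   SparseTensor ab = SparseTensor::Concat<float>({a, b});
  //   // ab.shape()[0] == a.shape()[0] + b.shape()[0]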

  // Split() splits the input SparseTensor into a list of num_split
  // SparseTensors along the given dimension. If the size of that dimension is
  // not evenly divisible by num_split, the first (size % num_split) slices
  // each receive one extra element along the split dimension.
  template <typename T>
  static Status Split(const SparseTensor& tensor, const int split_dim,
                      const int num_split, std::vector<SparseTensor>* result);
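
  // Illustrative sketch (assumes a float SparseTensor `st` with shape
  // {10, 4}): splitting along dimension 0 into 3 pieces yields slices of
  // sizes 4, 3, and 3 along that dimension.
  //
  //   std::vector<SparseTensor> pieces;
  //   Status s = SparseTensor::Split<float>(st, /*split_dim=*/0,
  //                                         /*num_split=*/3, &pieces);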

  // Slice() extracts a SparseTensor based on the specified start and size.
  // Both start and size are 1-D arrays with one element per dimension: start
  // gives the starting index and size gives the extent along each dimension.
  template <typename T>
  static StatusOr<SparseTensor> Slice(const SparseTensor& tensor,
                                      const gtl::ArraySlice<int64_t> start,
                                      const gtl::ArraySlice<int64_t> size);
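
  // Illustrative sketch (assumes a 2-D float SparseTensor `st`): keep rows
  // [1, 3) and columns [0, 2).
  //
  //   StatusOr<SparseTensor> sliced =
  //       SparseTensor::Slice<float>(st, /*start=*/{1, 0}, /*size=*/{2, 2});
  //   if (sliced.ok()) { /* use sliced.value() */ }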

  // Picks out the dimensions according to `dim_indices`.
  std::vector<int64_t> PickDims(gtl::ArraySlice<int64_t> dim_indices) const {
    std::vector<int64_t> res(dim_indices.size());
    for (size_t i = 0; i < dim_indices.size(); ++i) {
      res[i] = shape_[dim_indices[i]];
    }
    return res;
  }
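
  // Illustrative example: for a SparseTensor with shape {2, 3, 4},
  // PickDims({0, 2}) returns {2, 4}.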

 private:
  static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
    return ShapeArray(shape.size(), -1);
  }

  static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
    ShapeArray vec(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
    return vec;
  }

  // Optimized implementation of `IndicesValid` for 1-D sparse tensors.
  // REQUIRES: `shape_.size() == 1`.
  bool IndicesValidVectorFastPath() const;

  // Optimized implementation of `IndicesValid` for 2-D sparse tensors whose
  // indices fit within the range of an `int32`.
  // REQUIRES: `shape_.size() == 2`.
  bool IndicesValidMatrix32BitFastPath() const;

  template <bool standard_order>
  Status IndicesValidHelper() const;

  // Helper for ToDense<T>()
  template <typename T>
  bool ValidateAndInitializeToDense(Tensor* out, bool initialize);

  // Helper for Split() that returns the slice index.
  static inline int GetSliceIndex(const int dim, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim / split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim / (split_size + 1);
    } else {
      return residual + ((dim - offset) / split_size);
    }
  }

  // Helper for Split() that returns the dimension in the slice.
  static inline int GetDimensionInSlice(const int dim, const int split_size,
                                        const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim % split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim % (split_size + 1);
    } else {
      return (dim - offset) % split_size;
    }
  }

  // Helper for Split() that returns the shape given a slice index.
  static inline int GetSliceShape(const int slice_index, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(slice_index, 0);
    if (residual == 0) return split_size;
    if (slice_index < residual) {
      return split_size + 1;
    } else {
      return split_size;
    }
  }
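
  // Illustrative worked example for the three Split() helpers above (the
  // numbers are assumptions of this example): splitting a dimension of size
  // 10 into num_split = 3 gives split_size = 3 and residual = 1, so the
  // slice shapes are {4, 3, 3}. Then GetSliceIndex(5, 3, 1) == 1 and
  // GetDimensionInSlice(5, 3, 1) == 1, i.e. index 5 of the original dimension
  // maps to position 1 of slice 1.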

  Tensor ix_;
  Tensor vals_;
  ShapeArray shape_;
  ShapeArray order_;
  int dims_;
};

// This operation updates the indices and values Tensor rows, so it is
// an in-place algorithm. It requires O(N log N) time and O(N)
// temporary space.
template <typename T>
inline void SparseTensor::Reorder(const VarDimArray& order) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "Reorder requested with the wrong datatype";
  DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
  auto ix_t = ix_.matrix<int64_t>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64_t> reorder(num_entries());
  std::iota(reorder.begin(), reorder.end(), 0);

  // Sort to get order of indices
  switch (order.size()) {
#define CASE_SORT(ORDER_SIZE)                                    \
  case ORDER_SIZE: {                                             \
    FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \
    std::sort(reorder.begin(), reorder.end(), sorter);           \
    break;                                                       \
  }
    CASE_SORT(0);
    CASE_SORT(1);
    CASE_SORT(2);
    CASE_SORT(3);
    CASE_SORT(4);
    CASE_SORT(5);
#undef CASE_SORT
    default: {
      DimComparator sorter(ix_t, order, shape());
      std::sort(reorder.begin(), reorder.end(), sorter);
    }
  }

  // We have a forward reordering, but what we'll need is a
  // permutation (the inverse). This can be calculated with O(1) additional
  // space and O(n) time (INVPERM), but we just do the simple thing here.
  std::vector<size_t> permutation(reorder.size());
  for (std::size_t n = 0; n < reorder.size(); ++n) {
    permutation[reorder[n]] = n;
  }

  // Update indices & values by converting the permutations to
  // a product of transpositions. Iterate over the cycles in the
  // permutation, and convert each of those into a product of
  // transpositions (swaps):
  //   https://en.wikipedia.org/wiki/Cyclic_permutation
  // This is N swaps, 2*N comparisons.
  for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
    while (n != permutation[n]) {
      std::size_t r = permutation[n];
      std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
      std::swap(vals_t(n), vals_t(r));
      std::swap(permutation[n], permutation[r]);
    }
  }

  order_ = ShapeArray(order.begin(), order.end());
}

template <typename T>
inline bool SparseTensor::ValidateAndInitializeToDense(Tensor* out,
                                                       bool initialize) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "ToDense requested with the wrong datatype";

  DCHECK_EQ(out->shape().dims(), dims_)
      << "Incompatible dimensions between SparseTensor and output";

  DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
      << "Output must be type: " << DataTypeToEnum<T>::v()
      << " but got: " << out->dtype();

  // Make sure the dense output is the same rank and has room
  // to hold the SparseTensor.
  const auto& out_shape = out->shape();
  if (shape_.size() != out_shape.dims()) return false;
  for (int d = 0; d < shape_.size(); ++d) {
    if (shape_[d] > out_shape.dim_size(d)) return false;
  }

  if (initialize) {
    auto out_t = out->flat<T>();
    out_t.setConstant(T());
  }

  return true;
}

template <typename T>
inline bool SparseTensor::ToDense(Tensor* out, bool initialize) {
  if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;

  auto out_t = out->flat<T>();
  auto vals_t = vals_.vec<T>();
  auto ix_t = ix_.matrix<int64_t>();
  const int64_t* const ix_ptr = ix_t.data();

  if (dims_ == 1) {
    // Fast path for sparse vectors.
    const int64_t out_length = out->shape().dim_size(0);
    for (int n = 0; n < vals_t.dimension(0); ++n) {
      const int64_t index = internal::SubtleMustCopy(ix_ptr[n]);
      if (!FastBoundsCheck(index, out_length)) return false;
      out_t(index) = vals_t(n);
    }
    return true;
  } else if (dims_ == 2) {
    // Fast path for sparse matrices.
    const auto& out_shape = out->shape();
    const int64_t out_rows = out_shape.dim_size(0);
    const int64_t out_cols = out_shape.dim_size(1);
    for (int n = 0; n < vals_t.dimension(0); ++n) {
      const int64_t row_index = internal::SubtleMustCopy(ix_ptr[n * 2]);
      const int64_t col_index = internal::SubtleMustCopy(ix_ptr[n * 2 + 1]);
      if (!(FastBoundsCheck(row_index, out_rows) &&
            FastBoundsCheck(col_index, out_cols))) {
        return false;
      }
      out_t(row_index * out_cols + col_index) = vals_t(n);
    }
    return true;
  } else {
    // General path for N-dimensional sparse tensors.
    gtl::InlinedVector<int64_t, 4> strides(dims_);
    const auto& out_shape = out->shape().dim_sizes();
    if (dims_ > 0) {
      strides[dims_ - 1] = 1;
    }
    for (int d = dims_ - 2; d >= 0; --d) {
      strides[d] = strides[d + 1] * out_shape[d + 1];
    }

    for (int n = 0; n < vals_t.dimension(0); ++n) {
      bool invalid_dims = false;
      int64_t ix = 0;
      for (int d = 0; d < dims_; ++d) {
        const int64_t ix_n_d = internal::SubtleMustCopy(ix_ptr[n * dims_ + d]);
        if (!FastBoundsCheck(ix_n_d, out_shape[d])) {
          invalid_dims = true;
        }
        ix += strides[d] * ix_n_d;
      }
      if (invalid_dims) return false;
      out_t(ix) = vals_t(n);
    }
    return true;
  }
}

template <typename T>
inline SparseTensor SparseTensor::Concat(
    const gtl::ArraySlice<SparseTensor>& tensors) {
  DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
  const int dims = tensors[0].dims_;
  DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
  auto order_0 = tensors[0].order();
  const int primary_dim = order_0[0];
  ShapeArray final_order(order_0.begin(), order_0.end());
  ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
  final_shape[primary_dim] = 0;  // We'll build this up as we go along.
  int num_entries = 0;

  bool fully_ordered = true;
  for (const SparseTensor& st : tensors) {
    DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
    DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
        << "Concat requested with the wrong data type";
    DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
    DCHECK_EQ(st.order()[0], primary_dim)
        << "All SparseTensors' order[0] must match. This is the concat dim.";
    if (st.order() != final_order) fully_ordered = false;
    const VarDimArray& st_shape = st.shape();
    for (int d = 0; d < dims - 1; ++d) {
      const int cdim = (d < primary_dim) ? d : d + 1;
      DCHECK_EQ(final_shape[cdim], st_shape[cdim])
          << "All SparseTensors' shapes must match except on the concat dim. "
          << "Concat dim: " << primary_dim
          << ", mismatched shape at dim: " << cdim
          << ". Expecting shape like: [" << str_util::Join(final_shape, ",")
          << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
    }

    // Update dimension of final shape
    final_shape[primary_dim] =
        (final_shape[primary_dim] + st_shape[primary_dim]);

    num_entries += st.num_entries();  // Update number of entries
  }

  // If the inputs have inconsistent orderings, set the final order to all -1s
  // (undefined).
  if (!fully_ordered) {
    final_order = UndefinedOrder(final_shape);
  }

  Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
  Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));

  TTypes<int64_t>::Matrix ix_t = output_ix.matrix<int64_t>();
  typename TTypes<T>::Vec vals_t = output_vals.vec<T>();

  Eigen::DenseIndex offset = 0;
  int64_t shape_offset = 0;
  for (const SparseTensor& st : tensors) {
    const int st_num_entries = st.num_entries();

    // Fill in indices & values.
    if (st_num_entries > 0) {
      std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));

      const auto* st_ix = &st.ix_.matrix<int64_t>()(0, 0);
      auto* ix_out = &ix_t(offset, 0);
      for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
        *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
      }
    }

    offset += st_num_entries;
    shape_offset += st.shape()[primary_dim];
  }

  return SparseTensor(output_ix, output_vals, final_shape, final_order);
}

template <typename T>
inline Status SparseTensor::Split(const SparseTensor& input_tensor,
                                  const int split_dim, const int num_split,
                                  std::vector<SparseTensor>* result) {
  std::vector<Tensor> output_indices;
  std::vector<Tensor> output_values;
  std::vector<TensorShape> output_shapes;
  output_indices.reserve(num_split);
  output_values.reserve(num_split);
  output_shapes.reserve(num_split);

  std::vector<typename TTypes<int64_t>::Matrix> output_indices_t;
  std::vector<typename TTypes<T>::Vec> output_values_t;
  output_indices_t.reserve(num_split);
  output_values_t.reserve(num_split);
  auto input_values_t = input_tensor.values().vec<T>();
  auto input_indices_t = input_tensor.indices().matrix<int64_t>();

  std::vector<int> num_values(num_split, 0);
  const int num_dim = input_tensor.shape().size();
  const int split_dim_size = input_tensor.shape()[split_dim];
  const int split_size = split_dim_size / num_split;

  if (!(num_split > 0 && num_split <= split_dim_size)) {
    return errors::InvalidArgument("num_split must be in the interval (0, ",
                                   split_dim_size, "]");
  }
  if (!(split_dim >= 0 && split_dim < num_dim)) {
    return errors::InvalidArgument("split_dim must be in the interval [0, ",
                                   num_dim, ")");
  }

  const int residual = split_dim_size % num_split;
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_tensor.indices().matrix<int64_t>()(i, split_dim);
    int slice_index = GetSliceIndex(dim, split_size, residual);
    if (slice_index >= num_values.size()) {
      return errors::InvalidArgument("Slice index ", slice_index,
                                     " is larger than num_split.");
    }
    num_values[slice_index]++;
  }

  for (int i = 0; i < num_split; ++i) {
    // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
    output_indices.emplace_back(DT_INT64,
                                TensorShape({num_values[i], num_dim}));
    output_values.emplace_back(DataTypeToEnum<T>::v(),
                               TensorShape({num_values[i]}));
    output_shapes.emplace_back(input_tensor.shape());
    output_indices_t.emplace_back(output_indices[i].matrix<int64_t>());
    output_values_t.emplace_back(output_values[i].vec<T>());
    const int size = GetSliceShape(i, split_size, residual);
    output_shapes[i].set_dim(split_dim, size);
  }

  std::vector<int> values_inserted_in_slice(num_split, 0);
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_indices_t(i, split_dim);
    const int slice_index = GetSliceIndex(dim, split_size, residual);
    const int slice_dim = values_inserted_in_slice[slice_index]++;
    output_values_t[slice_index](slice_dim) = input_values_t(i);
    for (int j = 0; j < num_dim; ++j) {
      const int64_t original_dim = input_indices_t(i, j);
      output_indices_t[slice_index](slice_dim, j) =
          (j == split_dim)
              ? GetDimensionInSlice(original_dim, split_size, residual)
              : original_dim;
    }
  }

  result->clear();
  result->reserve(num_split);
  for (int i = 0; i < num_split; ++i) {
    SparseTensor tensor;
    Status create_status =
        Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
    if (!create_status.ok()) {
      return create_status;
    }
    result->push_back(std::move(tensor));
  }
  return OkStatus();
}

template <typename T>
inline StatusOr<SparseTensor> SparseTensor::Slice(
    const SparseTensor& input_tensor, const gtl::ArraySlice<int64_t> start,
    const gtl::ArraySlice<int64_t> size) {
  TensorShape output_shape(input_tensor.shape());

  const int dims = input_tensor.dims();
  for (int dim = 0; dim < dims; dim++) {
    // Determine the size of the result; if the selected slice goes beyond the
    // input boundary, the result will correspond to the size of the overlap
    // between the input and the selected slice.
    const int64_t input_size = output_shape.dim_size(dim);
    const int64_t start_index = start[dim];
    const int64_t slice_size = size[dim];

    if (start_index < input_size - slice_size) {
      // The entire selection is within input boundaries.
      TF_RETURN_IF_ERROR(output_shape.SetDimWithStatus(dim, slice_size));
    } else if (start_index < input_size) {
      // The selection starts within input boundaries, but goes beyond them.
      TF_RETURN_IF_ERROR(
          output_shape.SetDimWithStatus(dim, input_size - start_index));
    } else {
      // The selection is entirely out of input boundaries.
      TF_RETURN_IF_ERROR(output_shape.SetDimWithStatus(dim, 0));
    }
  }

  auto input_indices_t = input_tensor.indices().matrix<int64_t>();
  auto input_values_t = input_tensor.values().vec<T>();

  // Find the number of indices that fall inside start and size.
  int count = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
    // Check whether the index at row i falls within the range specified by
    // start and size. The loop below iterates through all dimensions; if the
    // index is outside the range in any dimension, it is treated as a
    // "no hit" (hit = false) and is not counted.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    count++;
  }

  Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
  Tensor output_indices(DT_INT64, TensorShape({count, dims}));

  auto output_values_t = output_values.vec<T>();
  auto output_indices_t = output_indices.matrix<int64_t>();

  // Obtain the output indices that fall inside start and size.
  int index = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
       i++) {
    // Same logic as above, except that here we actually populate the output
    // indices and values instead of just counting the matches.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    output_values_t(index) = input_values_t(i);
    for (int dim = 0; dim < dims; dim++) {
      output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
    }
    index++;
  }

  return SparseTensor(output_indices, output_values, output_shape);
}

}  // namespace sparse
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_