/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/strided_slice_op.h"

#include <algorithm>
#include <array>
#include <iterator>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

/// Constants
constexpr int32_t kShrinkAxis = -1, kNewAxis = -2;

// Sparse slicing specification.
// If one does foo[3:5, ..., -3], the begin, end, and strides tensors will
// each have length 3.
struct StridedSliceSparseSpec {
  int64_t dims;
  int32 num_add_axis_after_ellipsis;
  const Tensor* begin_tensor;
  const Tensor* end_tensor;
  const Tensor& strides_tensor;
  const int32 begin_mask, end_mask;
  int32 ellipsis_mask;
  const int32 new_axis_mask, shrink_axis_mask;
};
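
// For illustration only (an assumption about what a typical caller such as
// the Python slice helper builds; not exercised in this file): foo[3:5, ..., -3]
// would usually arrive as a sparse spec along the lines of
//   dims = 3, begin = [3, 0, -3], end = [5, 0, -2], strides = [1, 1, 1],
//   ellipsis_mask = 0b010, shrink_axis_mask = 0b100,
//   new_axis_mask = begin_mask = end_mask = 0.
// The exact encoding is determined by the caller, not by this file.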

// Dense slicing specification.
// All ellipses and newaxis entries are expanded out, so if
// foo[3:5, ..., -3] is applied to a 10-dimensional foo, each
// InlinedVector below will have 10 entries, whereas the sparse
// spec had length-3 tensors.
struct StridedSliceDenseSpec {
  const int64_t dims;
  int32 begin_mask;
  int32 end_mask;
  bool begin_valid;
  bool end_valid;
  gtl::InlinedVector<int64_t, 4>& begin;
  gtl::InlinedVector<int64_t, 4>& end;
  gtl::InlinedVector<int64_t, 4>& strides;
  // This vector helps construct the final shape of the slice.
  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
  // is called for. The final tensor increases in rank with tf.newaxis
  // entries. If an index in this array is non-negative, the size of the
  // dimension is obtained from canonical end-begin. Otherwise, if it is
  // kNewAxis, it will be 1. A shrunk dimension is skipped.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
  // This vector has the same size as final_shape_gather_indices, but it
  // remembers the sparse index that a dimension comes from, instead of the
  // dense index. A -1 in this vector means the index is not from the sparse
  // input.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
  // The dense-indexed shrink mask indicates which processing dimensions
  // should be shrunk. For example, if foo.shape = (10, 10, 10, 10),
  // foo[3, ..., 5] has a sparse_shrink_axis_mask of 0x5 and a
  // dense shrink_axis_mask of 0x9, yielding a final shape (10, 10).
  int32 shrink_axis_mask;
};
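
// For illustration (an assumed example, not derived from any input in this
// file; single indices are assumed to be encoded as begin = i, end = i + 1):
// expanding foo[3, ..., 5] on a 4-dimensional foo yields a dense spec with
//   begin   = [3, 0, 0, 5], end = [4, 0, 0, 6], strides = [1, 1, 1, 1],
//   begin_mask = end_mask = 0b0110  (the two dims produced by the ellipsis),
//   final_shape_gather_indices = [kShrinkAxis, 1, 2, kShrinkAxis],
// so the final shape keeps only processing dimensions 1 and 2.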

}  // namespace

template <class T>
static Status TF_MUST_USE_RESULT BuildDenseSpec(
    const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
  // Build expanded begin, end, strides, begin_mask, end_mask
  // to remove any ellipsis.
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  dense->input_shape_gather_indices_sparse.resize(dense->dims);
  // What indices to get the final shape from.
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;
  {
    int full_index = 0;

    const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
    dense->begin_valid = sparse.begin_tensor != nullptr;
    dense->end_valid = sparse.end_tensor != nullptr;

    const T* const begin_flat = sparse.begin_tensor != nullptr
                                    ? sparse.begin_tensor->vec<T>().data()
                                    : nullptr;
    const T* const end_flat = sparse.end_tensor != nullptr
                                  ? sparse.end_tensor->vec<T>().data()
                                  : nullptr;

    for (int i = 0; i < sparse.dims; i++) {
      if ((1 << i) & sparse.ellipsis_mask) {
        // Expand the ellipsis into the appropriate indices.
        // NOTE: this only works because we guaranteed one ellipsis.
        int32_t next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
                                          sparse.num_add_axis_after_ellipsis,
                                      dense->dims);
        for (; full_index < next_index; full_index++) {
          // new_axis entries aren't real axes, so skip them here.
          dense->begin[full_index] = dense->end[full_index] = 0;
          dense->strides[full_index] = 1;
          dense->begin_mask |= (1 << full_index);
          dense->end_mask |= (1 << full_index);
          dense->final_shape_gather_indices.push_back(full_index);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->input_shape_gather_indices_sparse[full_index] = i;
        }
      } else if ((1 << i) & sparse.new_axis_mask) {
        dense->final_shape_gather_indices.push_back(kNewAxis);
        dense->final_shape_gather_indices_sparse.push_back(-1);
      } else {
        if (full_index == dense->begin.size()) {
          return errors::InvalidArgument("Index out of range using input dim ",
                                         full_index, "; input has only ",
                                         dense->dims, " dims");
        }

        // Gather slicing spec into appropriate index.
        if (begin_flat != nullptr) {
          dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
        }
        if (end_flat != nullptr) {
          dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
        }
        dense->strides[full_index] =
            internal::SubtleMustCopy<T>(strides_flat[i]);
        if (sparse.begin_mask & (1 << i)) {
          dense->begin_mask |= (1 << full_index);
        }
        if (sparse.end_mask & (1 << i)) {
          dense->end_mask |= (1 << full_index);
        }
        // If shrink, record where to get the dimensionality from (i.e.
        // new_axis creates a fake 1-size dimension). Also remember the shrink
        // axis (now in dense form) so we can ignore dense->end below.
        if (sparse.shrink_axis_mask & (1 << i)) {
          dense->final_shape_gather_indices.push_back(kShrinkAxis);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->shrink_axis_mask |= (1 << full_index);
        } else {
          dense->final_shape_gather_indices.push_back(full_index);
          // Remember where in the sparse shape the dense dim comes from.
          dense->final_shape_gather_indices_sparse.push_back(i);
        }
        dense->input_shape_gather_indices_sparse[full_index] = i;
        full_index++;
      }
    }
  }
  return OkStatus();
}

Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
    int32_t new_axis_mask, int32_t shrink_axis_mask,
    PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
    gtl::InlinedVector<int64_t, 4>* strides,
    StridedSliceShapeSpec* shape_spec) {
  const bool begin_is_wrong =
      begin_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
        begin_tensor->NumElements() == strides_tensor.NumElements() &&
        begin_tensor->NumElements() < 32 /* using 32 bit masks */);
  const bool end_is_wrong =
      end_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
        end_tensor->NumElements() == strides_tensor.NumElements());
  if (begin_is_wrong || end_is_wrong ||
      !TensorShapeUtils::IsVector(strides_tensor.shape())) {
    if (begin_tensor != nullptr && end_tensor != nullptr) {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shapes ", begin_tensor->shape().DebugString(), ", ",
          end_tensor->shape().DebugString(), ", and ",
          strides_tensor.shape().DebugString(), " instead.");
    } else {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shape ", strides_tensor.shape().DebugString(),
          " for strides.");
    }
  }
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2,
  // i.e. there is at most one ellipsis.
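  // For example (illustrative): ellipsis_mask = 0b0100 gives
  // 0b0100 & 0b0011 == 0 and is accepted, whereas ellipsis_mask = 0b0101
  // gives 0b0101 & 0b0100 != 0 and is rejected below.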
  if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    return errors::InvalidArgument(
        "Multiple ellipses in slice spec not allowed");
  }

  // Step 1: Account for ellipsis and new axis.
  //
  // Check for the ellipsis and count how many non-newaxis entries there are
  // after it.
  // TODO(aselle): Convert this to do a fast log2 followed by iteration
  // counting ones in next guys.
  bool ellipsis_seen = false;

  StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
                                        0,
                                        begin_tensor,
                                        end_tensor,
                                        strides_tensor,
                                        begin_mask_spec,
                                        end_mask_spec,
                                        ellipsis_mask,
                                        new_axis_mask,
                                        shrink_axis_mask};

  for (int32_t i = 0; i < sparse_spec.dims; i++) {
    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
      sparse_spec.num_add_axis_after_ellipsis++;
    }
    if ((1 << i) & ellipsis_mask) {
      ellipsis_seen = true;
    }
  }
  // If no ellipsis was specified, insert one at the end.
  if (!ellipsis_seen) {
    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    sparse_spec.dims++;  // this affects loop iteration below
  }

  // Step 2: Make a sparse spec into a full index spec.
  //
  // The sparse spec does not correspond to the number of dimensions.
  // Make a dense spec that corresponds to the number of dimensions.
  //
  // For example, suppose foo[..., 3:] is applied to foo.shape = (2, 2, 3).
  // Then we need to produce the missing begin_mask for the first two
  // dimensions, i.e. from begin_mask_spec = 0, end_mask_spec = 2
  // we achieve begin_mask = 3 (0b011), end_mask = 7 (0b111).
  StridedSliceDenseSpec dense_spec = {input_shape.dims(),
                                      0 /* begin_mask */,
                                      0 /* end_mask */,
                                      false /* begin_valid */,
                                      false /* end_valid */,
                                      *begin,
                                      *end,
                                      *strides};

  if (strides_tensor.dtype() == DT_INT32) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT64) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int64_t>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT16) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int16_t>(sparse_spec, &dense_spec));
  } else {
    LOG(FATAL) << "begin, end, and strides must be either int16, int32 or int64";
  }

  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
  // and bounds check!
  *is_identity = true;
  *slice_dim0 = true;
  *is_simple_slice = true;
  processing_shape->Clear();
  for (int i = 0; i < input_shape.dims(); ++i) {
    int64_t& begin_i = (*begin)[i];
    int64_t& end_i = (*end)[i];
    int64_t& stride_i = (*strides)[i];
    int64_t dim_i = input_shape.dim_size(i);
    if (stride_i == 0) {
      return errors::InvalidArgument("strides[", i, "] must be non-zero");
    }
    bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
    if (dim_i == -1) {
      processing_shape->AddDim(shrink_i ? 1 : -1);
      continue;
    }

    const std::array<int64_t, 2> masks = {
        {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
    const std::array<int64_t, 2> valid_range = {
        {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};

    auto canonical = [stride_i, dim_i, masks, valid_range](int64_t x, int c) {
      if (masks[c]) {
        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
      } else {
        int64_t x_fwd =
            x < 0 ? dim_i + x : x;  // make negative indices positive
        return x_fwd < valid_range[0]
                   ? valid_range[0]
                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
      }
    };
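    // Illustrative behaviour of `canonical` above (not a separate code path):
    // with dim_i = 10 and stride_i = 1, canonical(-1, 0) maps -1 to 9 and
    // canonical(100, 1) clamps to 10; with stride_i = -1 the valid range
    // becomes [-1, 9], so a masked begin resolves to 9 and a masked end to -1.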
    if (shrink_i && stride_i <= 0) {
      return errors::InvalidArgument(
          "only stride 1 allowed on non-range indexing.");
    }
    (*is_simple_slice) &= stride_i == 1;

    const bool begin_and_end_masked =
        (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      if (shrink_i) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular foo[-1] produces sparse_begin = -1, sparse_end = 0,
        // and canonical puts these to n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin + 1.
        int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
        begin_i = x_fwd;
        end_i = begin_i + 1;
        if (x_fwd < 0 || x_fwd >= dim_i) {
          return errors::InvalidArgument(
              "slice index ", begin_i, " of dimension ", i, " out of bounds.");
        }
      } else {
        begin_i = canonical(begin_i, 0);
        end_i = canonical(end_i, 1);
      }
      // Update optimization values.
      bool take_all_in_dimension =
          stride_i == 1 && begin_i == 0 && end_i == dim_i;
      (*is_identity) &= take_all_in_dimension;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
    } else {
      (*is_identity) &= stride_i == 1 && begin_and_end_masked;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
    }
    // Compute the processing shape (the intermediate shape that Eigen will
    // produce).
    int64_t interval_length;
    bool known_interval = false;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      interval_length = end_i - begin_i;
      known_interval = true;
    } else if (shrink_i) {
      // The dimension is still known as 1 for the processing_shape, but will
      // be discarded for the final shape.
      interval_length = 1;
      known_interval = true;
    } else if (begin_and_end_masked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dim_i >= 0) {
        if (stride_i < 0) {
          interval_length = -dim_i;
        } else {
          interval_length = dim_i;
        }
        known_interval = true;
      }
    }
    if (known_interval) {
      int64_t size_i;
      // Hold zero if the interval is degenerate, otherwise account for
      // the remainder.
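      // For example (illustrative): interval_length = 7 with stride_i = 3
      // gives size_i = 3 (ceiling division), and interval_length = -7 with
      // stride_i = -3 also gives 3; interval_length = 7 with stride_i = -1
      // is degenerate and gives 0.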
      if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
        size_i = 0;
      } else {
        size_i = interval_length / stride_i +
                 (interval_length % stride_i != 0 ? 1 : 0);
      }
      processing_shape->AddDim(size_i);
    } else {
      processing_shape->AddDim(-1);
    }
  }

  // Step 4: Compute the final shape.
  //
  // new_axis will increase dimension by 1 (with a one-size dimension),
  // slices like foo[3, ...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
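  //
  // For illustration (an assumed example, not derived from the inputs here):
  // foo[3, tf.newaxis, :] on foo.shape = (10, 10) yields
  // final_shape_gather_indices = [kShrinkAxis, kNewAxis, 1] and
  // processing_shape = (1, 10), so the loop below produces
  // final_shape = (1, 10): the shrunk dimension is skipped, the new axis
  // contributes a 1, and index 1 gathers the size 10 from the processing
  // shape.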
  final_shape->Clear();
  if (shape_spec != nullptr) {
    shape_spec->output_to_sparse_mapping.clear();
    shape_spec->output_to_processing_mapping.clear();
    shape_spec->processing_to_sparse_mapping.assign(
        dense_spec.input_shape_gather_indices_sparse.begin(),
        dense_spec.input_shape_gather_indices_sparse.end());

    shape_spec->begin_dense_mask = dense_spec.begin_mask;
    shape_spec->end_dense_mask = dense_spec.end_mask;
    shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
  }

  for (int64_t dense_dim = 0;
       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
    int64_t gather_index = dense_spec.final_shape_gather_indices[dense_dim];
    int64_t sparse_index =
        dense_spec.final_shape_gather_indices_sparse[dense_dim];
    if (gather_index >= 0) {
      final_shape->AddDim(processing_shape->dim_size(gather_index));
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(sparse_index);
        shape_spec->output_to_processing_mapping.push_back(gather_index);
      }
    } else if (gather_index == kNewAxis) {
      final_shape->AddDim(1);
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(-1);
        shape_spec->output_to_processing_mapping.push_back(-1);
      }
    }
  }

  return OkStatus();
}

Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
    int32_t new_axis_mask, int32_t shrink_axis_mask,
    TensorShape* processing_shape, TensorShape* final_shape, bool* is_identity,
    bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
    gtl::InlinedVector<int64_t, 4>* strides,
    StridedSliceShapeSpec* shape_spec) {
  // Validate with PartialTensorShape output.
  PartialTensorShape partial_processing_shape, partial_final_shape;
  TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
      begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
      end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
      &partial_processing_shape, &partial_final_shape, is_identity,
      is_simple_slice, slice_dim0, begin, end, strides, shape_spec));

  // Verify that the output shapes are fully known.
  if (!partial_processing_shape.AsTensorShape(processing_shape) ||
      !partial_final_shape.AsTensorShape(final_shape)) {
    return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
                            partial_processing_shape.DebugString(), " and ",
                            partial_final_shape.DebugString());
  }
  return OkStatus();
}

StridedSliceAssignBCast::StridedSliceAssignBCast(
    const StridedSliceAssignBCast::Vec& input_shape,
    const StridedSliceAssignBCast::Vec& output_shape)
    : valid_(true),
      broadcasting_required_(false),
      reshape_(output_shape.size()),
      bcast_(output_shape.size()),
      result_shape_(output_shape) {
  // The input needs to be reshaped to have the same number of dimensions as
  // the output. This is accomplished by either prepending ones or removing
  // leading dimensions, as necessary.
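  //
  // For example (illustrative): input_shape = (3, 1) and
  // output_shape = (2, 3, 4) produce reshape_ = (1, 3, 1) and
  // bcast_ = (2, 1, 4), with broadcasting_required_ = true.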
  size_t input_start = 0;
  size_t prepend_size = 0;
  if (output_shape.size() < input_shape.size()) {
    // Numpy allows assigning a larger rank array to smaller as long as
    // broadcasting would otherwise work and the prefix dimensions are all 1.
    // Though this behavior is undocumented, we allow it here for consistency.
    // See https://github.com/numpy/numpy/issues/21744 for details.
    input_start = input_shape.size() - output_shape.size();
    for (size_t i = 0; i < input_start; ++i) {
      if (input_shape[i] != 1) {
        valid_ = false;
        return;
      }
    }
  } else {
    prepend_size = output_shape.size() - input_shape.size();
  }
  std::fill_n(reshape_.begin(), prepend_size, 1);
  std::copy(input_shape.begin() + input_start, input_shape.end(),
            reshape_.begin() + prepend_size);

  // In order to broadcast, dimensions must either be equal or one.
  for (size_t i = 0; i < output_shape.size(); ++i) {
    if (reshape_[i] == output_shape[i]) {
      bcast_[i] = 1;
    } else if (reshape_[i] == 1) {
      bcast_[i] = output_shape[i];
      broadcasting_required_ = true;
    } else {
      valid_ = false;
      return;
    }
  }
}

bool StridedSliceAssignBCast::RemapDimensions(
    int64_t num_dims, const StridedSliceAssignBCast::Vec& dimension_map) {
  // Each element in the map corresponds to the original result shape, so
  // the sizes must be equal.
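  //
  // For example (illustrative): with result_shape_ = (2, 3),
  // dimension_map = {0, 2}, and num_dims = 3, the shapes are scattered into
  // rank-3 vectors, giving result_shape_ = (2, 1, 3); a -1 entry in the map
  // drops that dimension of the old shape.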
  if (dimension_map.size() != result_shape_.size()) {
    return false;
  }

  // Ensure all indices are within bounds before any modifications are made -
  // otherwise we could be left in a corrupted state.
  for (size_t i = 0; i < dimension_map.size(); ++i) {
    int64_t dim = dimension_map[i];
    if (dim >= num_dims) {
      return false;
    }
  }

  Vec old_reshape = std::move(reshape_);
  Vec old_bcast = std::move(bcast_);
  Vec old_result_shape = std::move(result_shape_);
  reshape_ = Vec(num_dims);
  bcast_ = Vec(num_dims);
  result_shape_ = Vec(num_dims);
  std::fill_n(reshape_.begin(), num_dims, 1);
  std::fill_n(bcast_.begin(), num_dims, 1);
  std::fill_n(result_shape_.begin(), num_dims, 1);
  for (size_t i = 0; i < dimension_map.size(); ++i) {
    int64_t dim = dimension_map[i];
    if (dim >= 0) {
      reshape_[dim] = old_reshape[i];
      bcast_[dim] = old_bcast[i];
      result_shape_[dim] = old_result_shape[i];
    }
  }

  return true;
}

}  // namespace tensorflow
