xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/strided_slice_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/strided_slice_op.h"
17 
18 #include <algorithm>
19 #include <array>
20 #include <iterator>
21 
22 #include "tensorflow/core/framework/bounds_check.h"
23 #include "tensorflow/core/lib/core/status.h"
24 
25 namespace tensorflow {
26 namespace {
27 
28 /// Constants
29 constexpr int32_t kShrinkAxis = -1, kNewAxis = -2;
30 
31 // Sparse slicing specification
32 // if one does foo[3:5, ..., -3], this will have 3 length tensors
33 struct StridedSliceSparseSpec {
34   int64_t dims;
35   int32 num_add_axis_after_ellipsis;
36   const Tensor* begin_tensor;
37   const Tensor* end_tensor;
38   const Tensor& strides_tensor;
39   const int32 begin_mask, end_mask;
40   int32 ellipsis_mask;
41   const int32 new_axis_mask, shrink_axis_mask;
42 };
43 
44 // Dense slicing specification
45 // all ellipses and newaxis' are expanded out. So if
46 // foo[3:5, ..., -3] where foo is 10 dimensional,
47 // each inlinedVector will have 10 entries whereas the
48 // sparse had 3 length tensors.
49 struct StridedSliceDenseSpec {
50   const int64_t dims;
51   int32 begin_mask;
52   int32 end_mask;
53   bool begin_valid;
54   bool end_valid;
55   gtl::InlinedVector<int64_t, 4>& begin;
56   gtl::InlinedVector<int64_t, 4>& end;
57   gtl::InlinedVector<int64_t, 4>& strides;
58   // This vector helps construct the final shape of the slice.
59   // The final tensor is reduced in rank whenever a single index e.g. foo[3]
60   // is called for. The final tensor increases in rank with tf.newaxis
61   // entries. If an index in this array is positive, the size of the dimension
62   // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
63   // it will be 1. A shrunk dimension is skipped.
64   gtl::InlinedVector<int32, 4> final_shape_gather_indices;
65   // This vector has the same size as final_shape_gather_indices, but it
66   // remembers the sparse index that a dimension comes from, instead of dense
67   // index. A -1 in this vector means there the index is not from the sparse
68   // input.
69   gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
70   gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
71   // The dense indexed shrink mask is which processing dimensions
72   // should be shrunk. For example, if foo.shape = (10,10,10,10)
73   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
74   // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
75   int32 shrink_axis_mask;
76 };
77 
78 }  // namespace
79 
80 template <class T>
BuildDenseSpec(const StridedSliceSparseSpec & sparse,StridedSliceDenseSpec * dense)81 static Status TF_MUST_USE_RESULT BuildDenseSpec(
82     const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
83   // Build expanded begin, end, strides, begin_mask, end_mask
84   // to remove any ellipsis
85   dense->begin.resize(dense->dims);
86   dense->end.resize(dense->dims);
87   dense->strides.resize(dense->dims);
88   dense->input_shape_gather_indices_sparse.resize(dense->dims);
89   // What indices to get the final shape from.
90   dense->begin_mask = 0;
91   dense->end_mask = 0;
92   dense->shrink_axis_mask = 0;
93   {
94     int full_index = 0;
95 
96     const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
97     dense->begin_valid = sparse.begin_tensor != nullptr;
98     dense->end_valid = sparse.end_tensor != nullptr;
99 
100     const T* const begin_flat = sparse.begin_tensor != nullptr
101                                     ? sparse.begin_tensor->vec<T>().data()
102                                     : nullptr;
103     const T* const end_flat = sparse.end_tensor != nullptr
104                                   ? sparse.end_tensor->vec<T>().data()
105                                   : nullptr;
106 
107     for (int i = 0; i < sparse.dims; i++) {
108       if ((1 << i) & sparse.ellipsis_mask) {
109         // Expand the ellipsis into the appropriate indices
110         // NOTE: this only works because we guaranteed one ellipsis
111         int32_t next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
112                                           sparse.num_add_axis_after_ellipsis,
113                                       dense->dims);
114         for (; full_index < next_index; full_index++) {
115           // new_axis' aren't real axis so you have to skip
116           dense->begin[full_index] = dense->end[full_index] = 0;
117           dense->strides[full_index] = 1;
118           dense->begin_mask |= (1 << full_index);
119           dense->end_mask |= (1 << full_index);
120           dense->final_shape_gather_indices.push_back(full_index);
121           dense->final_shape_gather_indices_sparse.push_back(-1);
122           dense->input_shape_gather_indices_sparse[full_index] = i;
123         }
124       } else if ((1 << i) & sparse.new_axis_mask) {
125         dense->final_shape_gather_indices.push_back(kNewAxis);
126         dense->final_shape_gather_indices_sparse.push_back(-1);
127       } else {
128         if (full_index == dense->begin.size()) {
129           return errors::InvalidArgument("Index out of range using input dim ",
130                                          full_index, "; input has only ",
131                                          dense->dims, " dims");
132         }
133 
134         // Gather slicing spec into appropriate index
135         if (begin_flat != nullptr) {
136           dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
137         }
138         if (end_flat != nullptr) {
139           dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
140         }
141         dense->strides[full_index] =
142             internal::SubtleMustCopy<T>(strides_flat[i]);
143         if (sparse.begin_mask & (1 << i)) {
144           dense->begin_mask |= (1 << full_index);
145         }
146         if (sparse.end_mask & (1 << i)) {
147           dense->end_mask |= (1 << full_index);
148         }
149         // If shrink, record where to get the dimensionality from (i.e.
150         // new_axis creates a fake 1 size dimension. Also remember shrink
151         // axis (now in dense form) so we can ignore dense->end below.
152         if (sparse.shrink_axis_mask & (1 << i)) {
153           dense->final_shape_gather_indices.push_back(kShrinkAxis);
154           dense->final_shape_gather_indices_sparse.push_back(-1);
155           dense->shrink_axis_mask |= (1 << full_index);
156         } else {
157           dense->final_shape_gather_indices.push_back(full_index);
158           // Remember that where in the sparse shape the dense dim comes
159           // from.
160           dense->final_shape_gather_indices_sparse.push_back(i);
161         }
162         dense->input_shape_gather_indices_sparse[full_index] = i;
163         full_index++;
164       }
165     }
166   }
167   return OkStatus();
168 }
169 
ValidateStridedSliceOp(const Tensor * begin_tensor,const Tensor * end_tensor,const Tensor & strides_tensor,const PartialTensorShape & input_shape,int32_t begin_mask_spec,int32_t end_mask_spec,const int32_t ellipsis_mask,int32_t new_axis_mask,int32_t shrink_axis_mask,PartialTensorShape * processing_shape,PartialTensorShape * final_shape,bool * is_identity,bool * is_simple_slice,bool * slice_dim0,gtl::InlinedVector<int64_t,4> * begin,gtl::InlinedVector<int64_t,4> * end,gtl::InlinedVector<int64_t,4> * strides,StridedSliceShapeSpec * shape_spec)170 Status ValidateStridedSliceOp(
171     const Tensor* begin_tensor, const Tensor* end_tensor,
172     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
173     int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
174     int32_t new_axis_mask, int32_t shrink_axis_mask,
175     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
176     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
177     gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
178     gtl::InlinedVector<int64_t, 4>* strides,
179     StridedSliceShapeSpec* shape_spec) {
180   const bool begin_is_wrong =
181       begin_tensor != nullptr &&
182       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
183         begin_tensor->NumElements() == strides_tensor.NumElements() &&
184         begin_tensor->NumElements() < 32 /* using 32 bit masks */);
185   const bool end_is_wrong =
186       end_tensor != nullptr &&
187       !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
188         end_tensor->NumElements() == strides_tensor.NumElements());
189   if (begin_is_wrong || end_is_wrong ||
190       !TensorShapeUtils::IsVector(strides_tensor.shape())) {
191     if (begin_tensor != nullptr && end_tensor != nullptr) {
192       return errors::InvalidArgument(
193           "Expected begin, end, and strides to be 1D equal size tensors, ",
194           "but got shapes ", begin_tensor->shape().DebugString(), ", ",
195           end_tensor->shape().DebugString(), ", and ",
196           strides_tensor.shape().DebugString(), " instead.");
197     } else {
198       return errors::InvalidArgument(
199           "Expected begin, end, and strides to be 1D equal size tensors, ",
200           "but got shape ", strides_tensor.shape().DebugString(),
201           " for strides.");
202     }
203   }
204   // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
205   // i.e. there exists only no more than one ellipsis
206   if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
207     return errors::InvalidArgument(
208         "Multiple ellipses in slice spec not allowed");
209   }
210 
211   // Step 1: Account for ellipsis and new axis
212   //
213   // Check for ellipses and count how many non-newaxis' there are after
214   // TODO(aselle): Convert this to do a fast log2 followed by iteration
215   //               counting ones in next guys
216   bool ellipsis_seen = false;
217 
218   StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
219                                         0,
220                                         begin_tensor,
221                                         end_tensor,
222                                         strides_tensor,
223                                         begin_mask_spec,
224                                         end_mask_spec,
225                                         ellipsis_mask,
226                                         new_axis_mask,
227                                         shrink_axis_mask};
228 
229   for (int32_t i = 0; i < sparse_spec.dims; i++) {
230     if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
231       sparse_spec.num_add_axis_after_ellipsis++;
232     }
233     if ((1 << i) & ellipsis_mask) {
234       ellipsis_seen = true;
235     }
236   }
237   // If no ellipsis insert one at the end
238   if (!ellipsis_seen) {
239     sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
240     sparse_spec.dims++;  // this effects loop iteration below
241   }
242 
243   // Step 2: Make a sparse spec into a full index spec
244   //
245   // The sparse spec does not correspond to the number of dimensions
246   // Make a dense spec that corresponds to the number of dimensions
247   //
248   // For example suppose foo[...,3:] on foo.shape=(2,2,3) then
249   // we need to produce the missing begin_mask for the first two
250   // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
251   // we achieve begin_mask=6, end_mask=7
252   StridedSliceDenseSpec dense_spec = {input_shape.dims(),
253                                       0 /* begin_mask */,
254                                       0 /* end_mask */,
255                                       false /* begin_valid */,
256                                       false /* end_valid */,
257                                       *begin,
258                                       *end,
259                                       *strides};
260 
261   if (strides_tensor.dtype() == DT_INT32) {
262     TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
263   } else if (strides_tensor.dtype() == DT_INT64) {
264     TF_RETURN_IF_ERROR(BuildDenseSpec<int64_t>(sparse_spec, &dense_spec));
265   } else if (strides_tensor.dtype() == DT_INT16) {
266     TF_RETURN_IF_ERROR(BuildDenseSpec<int16_t>(sparse_spec, &dense_spec));
267   } else {
268     LOG(FATAL) << "begin must be either int16, int32 or int64";
269   }
270 
271   // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
272   //         and bounds check!
273   *is_identity = true;
274   *slice_dim0 = true;
275   *is_simple_slice = true;
276   processing_shape->Clear();
277   for (int i = 0; i < input_shape.dims(); ++i) {
278     int64_t& begin_i = (*begin)[i];
279     int64_t& end_i = (*end)[i];
280     int64_t& stride_i = (*strides)[i];
281     int64_t dim_i = input_shape.dim_size(i);
282     if (stride_i == 0) {
283       return errors::InvalidArgument("strides[", i, "] must be non-zero");
284     }
285     bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
286     if (dim_i == -1) {
287       processing_shape->AddDim(shrink_i ? 1 : -1);
288       continue;
289     }
290 
291     const std::array<int64_t, 2> masks = {
292         {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
293     const std::array<int64_t, 2> valid_range = {
294         {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};
295 
296     auto canonical = [stride_i, dim_i, masks, valid_range](int64_t x, int c) {
297       if (masks[c]) {
298         return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
299       } else {
300         int64_t x_fwd =
301             x < 0 ? dim_i + x : x;  // make negative indices positive
302         return x_fwd < valid_range[0]
303                    ? valid_range[0]
304                    : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
305       }
306     };
307     if (shrink_i && stride_i <= 0) {
308       return errors::InvalidArgument(
309           "only stride 1 allowed on non-range indexing.");
310     }
311     (*is_simple_slice) &= stride_i == 1;
312 
313     const bool begin_and_end_masked =
314         (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
315     if (dense_spec.begin_valid && dense_spec.end_valid) {
316       if (shrink_i) {
317         // If we are shrinking, the end index is now possibly incorrect. In
318         // particular foo[-1] produces sparse_begin = -1, sparse_end = 0.
319         // and canonical puts these to n-1 and 0, which implies a degenerate
320         // interval. Fortunately, it is now safe to re-create end as begin+1.
321         int64_t x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
322         begin_i = x_fwd;
323         end_i = begin_i + 1;
324         if (x_fwd < 0 || x_fwd >= dim_i) {
325           return errors::InvalidArgument(
326               "slice index ", begin_i, " of dimension ", i, " out of bounds.");
327         }
328       } else {
329         begin_i = canonical(begin_i, 0);
330         end_i = canonical(end_i, 1);
331       }
332       // Update optimization values
333       bool take_all_in_dimension =
334           stride_i == 1 && begin_i == 0 && end_i == dim_i;
335       (*is_identity) &= take_all_in_dimension;
336       (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
337     } else {
338       (*is_identity) &= stride_i == 1 && begin_and_end_masked;
339       (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
340     }
341     // Compute the processing shape (the intermediate Eigen will produce)
342     int64_t interval_length;
343     bool known_interval = false;
344     if (dense_spec.begin_valid && dense_spec.end_valid) {
345       interval_length = end_i - begin_i;
346       known_interval = true;
347     } else if (shrink_i) {
348       // The dimension is still known as 1 for the processing_shape, but will be
349       // discarded for the final shape.
350       interval_length = 1;
351       known_interval = true;
352     } else if (begin_and_end_masked) {
353       // Even if we don't have values for begin or end, we do know that this
354       // dimension covers the whole interval. If we have shape information for
355       // this dimension, that tells us the interval length.
356       if (dim_i >= 0) {
357         if (stride_i < 0) {
358           interval_length = -dim_i;
359         } else {
360           interval_length = dim_i;
361         }
362         known_interval = true;
363       }
364     }
365     if (known_interval) {
366       int64_t size_i;
367       // Hold zero if the interval is degenerate, otherwise account for
368       // remainder
369       if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
370         size_i = 0;
371       } else {
372         size_i = interval_length / stride_i +
373                  (interval_length % stride_i != 0 ? 1 : 0);
374       }
375       processing_shape->AddDim(size_i);
376     } else {
377       processing_shape->AddDim(-1);
378     }
379   }
380 
381   // Step 4: Compute the final shape
382   //
383   // new_axis will increase dimension by 1 (with a one-size dimension)
384   // slices like foo[3,...] will reduce dimension by 1.
385   // This cannot be done earlier, because it depends on Step 3.
386   final_shape->Clear();
387   if (shape_spec != nullptr) {
388     shape_spec->output_to_sparse_mapping.clear();
389     shape_spec->output_to_processing_mapping.clear();
390     shape_spec->processing_to_sparse_mapping.assign(
391         dense_spec.input_shape_gather_indices_sparse.begin(),
392         dense_spec.input_shape_gather_indices_sparse.end());
393 
394     shape_spec->begin_dense_mask = dense_spec.begin_mask;
395     shape_spec->end_dense_mask = dense_spec.end_mask;
396     shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
397   }
398 
399   for (int64_t dense_dim = 0;
400        dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
401     int64_t gather_index = dense_spec.final_shape_gather_indices[dense_dim];
402     int64_t sparse_index =
403         dense_spec.final_shape_gather_indices_sparse[dense_dim];
404     if (gather_index >= 0) {
405       final_shape->AddDim(processing_shape->dim_size(gather_index));
406       if (shape_spec != nullptr) {
407         shape_spec->output_to_sparse_mapping.push_back(sparse_index);
408         shape_spec->output_to_processing_mapping.push_back(gather_index);
409       }
410     } else if (gather_index == kNewAxis) {
411       final_shape->AddDim(1);
412       if (shape_spec != nullptr) {
413         shape_spec->output_to_sparse_mapping.push_back(-1);
414         shape_spec->output_to_processing_mapping.push_back(-1);
415       }
416     }
417   }
418 
419   return OkStatus();
420 }
421 
ValidateStridedSliceOp(const Tensor * begin_tensor,const Tensor * end_tensor,const Tensor & strides_tensor,const PartialTensorShape & input_shape,int32_t begin_mask_spec,int32_t end_mask_spec,const int32_t ellipsis_mask,int32_t new_axis_mask,int32_t shrink_axis_mask,TensorShape * processing_shape,TensorShape * final_shape,bool * is_identity,bool * is_simple_slice,bool * slice_dim0,gtl::InlinedVector<int64_t,4> * begin,gtl::InlinedVector<int64_t,4> * end,gtl::InlinedVector<int64_t,4> * strides,StridedSliceShapeSpec * shape_spec)422 Status ValidateStridedSliceOp(
423     const Tensor* begin_tensor, const Tensor* end_tensor,
424     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
425     int32_t begin_mask_spec, int32_t end_mask_spec, const int32_t ellipsis_mask,
426     int32_t new_axis_mask, int32_t shrink_axis_mask,
427     TensorShape* processing_shape, TensorShape* final_shape, bool* is_identity,
428     bool* is_simple_slice, bool* slice_dim0,
429     gtl::InlinedVector<int64_t, 4>* begin, gtl::InlinedVector<int64_t, 4>* end,
430     gtl::InlinedVector<int64_t, 4>* strides,
431     StridedSliceShapeSpec* shape_spec) {
432   // Validate with PartialTensorShape output
433   PartialTensorShape partial_processing_shape, partial_final_shape;
434   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
435       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
436       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
437       &partial_processing_shape, &partial_final_shape, is_identity,
438       is_simple_slice, slice_dim0, begin, end, strides, shape_spec));
439 
440   // Verify that the output shapes are fully known
441   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
442       !partial_final_shape.AsTensorShape(final_shape)) {
443     return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
444                             partial_processing_shape.DebugString(), " and ",
445                             partial_final_shape.DebugString());
446   }
447   return OkStatus();
448 }
449 
// Describes how a value of shape `input_shape` is reshaped and broadcast to
// fill a slice of shape `output_shape`.
//
// On success:
//   * reshape_ holds input_shape brought to output_shape's rank, either by
//     prepending 1s or by dropping leading dimensions that are all 1.
//   * bcast_ holds the per-dimension broadcast multiple.
//   * broadcasting_required_ is set if any multiple differs from 1.
// valid_ is cleared if the shapes are not broadcast-compatible.
StridedSliceAssignBCast::StridedSliceAssignBCast(
    const StridedSliceAssignBCast::Vec& input_shape,
    const StridedSliceAssignBCast::Vec& output_shape)
    : valid_(true),
      broadcasting_required_(false),
      reshape_(output_shape.size()),
      bcast_(output_shape.size()),
      result_shape_(output_shape) {
  const size_t out_rank = output_shape.size();
  const size_t in_rank = input_shape.size();
  const bool input_has_extra_dims = out_rank < in_rank;

  // Number of leading input dimensions to discard, and number of leading 1s
  // to prepend. At most one of the two is non-zero.
  const size_t dropped = input_has_extra_dims ? in_rank - out_rank : 0;
  const size_t padded = input_has_extra_dims ? 0 : out_rank - in_rank;
  const auto kept_begin = input_shape.begin() + dropped;

  // NumPy accepts assigning a higher-rank array to a lower-rank slice when the
  // surplus leading dimensions are all 1 and broadcasting otherwise works.
  // This is undocumented, but it is mirrored here for consistency.
  // See https://github.com/numpy/numpy/issues/21744 for details.
  const bool dropped_all_ones =
      std::all_of(input_shape.begin(), kept_begin,
                  [](const auto& extent) { return extent == 1; });
  if (!dropped_all_ones) {
    valid_ = false;
    return;
  }

  std::fill_n(reshape_.begin(), padded, 1);
  std::copy(kept_begin, input_shape.end(), reshape_.begin() + padded);

  // Broadcasting requires every dimension to match exactly or to be 1.
  for (size_t d = 0; d < out_rank; ++d) {
    if (reshape_[d] == output_shape[d]) {
      bcast_[d] = 1;
      continue;
    }
    if (reshape_[d] != 1) {
      valid_ = false;
      return;
    }
    bcast_[d] = output_shape[d];
    broadcasting_required_ = true;
  }
}
495 
RemapDimensions(int64_t num_dims,const StridedSliceAssignBCast::Vec & dimension_map)496 bool StridedSliceAssignBCast::RemapDimensions(
497     int64_t num_dims, const StridedSliceAssignBCast::Vec& dimension_map) {
498   // Each element in the map corresponds to the original result shape, so
499   // the sizes must be equal.
500   if (dimension_map.size() != result_shape_.size()) {
501     return false;
502   }
503 
504   // Ensure all indices are within-bounds before any modifications are made -
505   // otherwise we could be left in a corrupted state.
506   for (size_t i = 0; i < dimension_map.size(); ++i) {
507     int64_t dim = dimension_map[i];
508     if (dim >= num_dims) {
509       return false;
510     }
511   }
512 
513   Vec old_reshape = std::move(reshape_);
514   Vec old_bcast = std::move(bcast_);
515   Vec old_result_shape = std::move(result_shape_);
516   reshape_ = Vec(num_dims);
517   bcast_ = Vec(num_dims);
518   result_shape_ = Vec(num_dims);
519   std::fill_n(reshape_.begin(), num_dims, 1);
520   std::fill_n(bcast_.begin(), num_dims, 1);
521   std::fill_n(result_shape_.begin(), num_dims, 1);
522   for (size_t i = 0; i < dimension_map.size(); ++i) {
523     int64_t dim = dimension_map[i];
524     if (dim >= 0) {
525       reshape_[dim] = old_reshape[i];
526       bcast_[dim] = old_bcast[i];
527       result_shape_[dim] = old_result_shape[i];
528     }
529   }
530 
531   return true;
532 }
533 
534 }  // namespace tensorflow
535