xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/tensor_shape.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_shape.h"
17 
18 #include "tensorflow/core/framework/bounds_check.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/lib/strings/strcat.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/platform/macros.h"
25 #include "tensorflow/core/util/overflow.h"
26 
27 namespace tensorflow {
28 
29 // TensorShape and PartialTensorShape should have no fields beyond
30 // TensorShapeRep.  In particular, their sizes should be the same.
31 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
32               "TensorShape must have no fields beyond TensorShapeRep");
33 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
34               "PartialTensorShape must have no fields beyond TensorShapeRep");
35 
36 template <class Shape>
AppendTo(const TensorShapeBase<Shape> & s,gtl::InlinedVector<int64,8> * vals)37 static void AppendTo(const TensorShapeBase<Shape>& s,
38                      gtl::InlinedVector<int64, 8>* vals) {
39   for (auto dim : s) {
40     vals->push_back(dim.size);
41   }
42 }
43 
CheckDimsEqual(int NDIMS) const44 void TensorShape::CheckDimsEqual(int NDIMS) const {
45   CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
46                           << " from a tensor of " << dims() << " dimensions";
47 }
48 
CheckDimsAtMost(int NDIMS) const49 void TensorShape::CheckDimsAtMost(int NDIMS) const {
50   CHECK_GE(NDIMS, dims()) << "Asking for tensor of at most " << NDIMS
51                           << " dimensions from a tensor of " << dims()
52                           << " dimensions";
53 }
54 
55 // TODO(slebedev): Consider merging IsValid implementations.
56 template <class Shape>
IsValid()57 bool TensorShapeBase<Shape>::IsValid() {
58   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
59   // unknown_shape() set, and it seems hard to remove this without backwards
60   // compatibility issues.
61   if (kIsPartial && unknown_rank()) return dims() == 0;
62   int64_t num_elements = 1;
63   if (dims() > MaxDimensions()) return false;
64   for (auto d : dim_sizes()) {
65     if (d < (kIsPartial ? -1 : 0)) return false;
66     if (d == -1) {
67       num_elements = -1;
68     } else if (!kIsPartial || num_elements >= 0) {
69       num_elements = MultiplyWithoutOverflow(num_elements, d);
70       if (num_elements < 0) return false;
71     }
72   }
73   return true;
74 }
75 
76 template <class Shape>
IsValid(const TensorShapeProto & proto)77 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
78   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
79   // unknown_shape() set, and it seems hard to remove this without backwards
80   // compatibility issues.
81   if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
82   int64_t num_elements = 1;
83   if (proto.dim().size() > MaxDimensions()) return false;
84   for (const auto& d : proto.dim()) {
85     if (d.size() < (kIsPartial ? -1 : 0)) return false;
86     if (d.size() == -1) {
87       num_elements = -1;
88     } else if (!kIsPartial || num_elements >= 0) {
89       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
90       if (num_elements < 0) return false;
91     }
92   }
93   return true;
94 }
95 
96 template <class Shape>
IsValidShape(const TensorShapeProto & proto)97 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
98   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
99   // unknown_shape() set, and it seems hard to remove this without backwards
100   // compatibility issues.
101   if (kIsPartial && proto.unknown_rank()) {
102     if (proto.dim_size() > 0) {
103       return errors::InvalidArgument(
104           "An unknown shape must not have any dimensions set.");
105     }
106     return OkStatus();
107   }
108   int64_t num_elements = 1;
109   if (proto.dim().size() > MaxDimensions()) {
110     return errors::InvalidArgument("Shape ", DebugString(proto),
111                                    " has too many dimensions");
112   }
113   for (const auto& d : proto.dim()) {
114     if (d.size() < (kIsPartial ? -1 : 0)) {
115       if (kIsPartial) {
116         return errors::InvalidArgument(
117             "Shape ", DebugString(proto),
118             " has dimensions with values below -1 (where -1 means unknown)");
119       } else {
120         return errors::InvalidArgument("Shape ", DebugString(proto),
121                                        " is not fully defined");
122       }
123     }
124     if (d.size() == -1) {
125       num_elements = -1;
126     } else if (!kIsPartial || num_elements >= 0) {
127       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
128       if (num_elements < 0) {
129         return errors::InvalidArgument(
130             "Shape ", DebugString(proto),
131             " is too large (more than 2**63 - 1 entries)");
132       }
133     }
134   }
135   return OkStatus();
136 }
137 
138 template <class Shape>
TensorShapeBase(const TensorShapeProto & proto)139 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
140   set_tag(REP16);
141   set_data_type(DT_INVALID);
142   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
143   // unknown_shape() set, and it seems hard to remove this without backwards
144   // compatibility issues.
145   if (kIsPartial && proto.unknown_rank()) {
146     set_ndims_byte(kUnknownRank);
147     set_num_elements(-1);
148   } else {
149     set_ndims_byte(0);
150     set_num_elements(1);
151     for (const auto& d : proto.dim()) {
152       AddDim(d.size());
153     }
154   }
155 }
156 
157 template <class Shape>
BuildTensorShapeBase(const TensorShapeProto & proto,TensorShapeBase * out)158 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
159     const TensorShapeProto& proto, TensorShapeBase* out) {
160   out->set_tag(REP16);
161   out->set_data_type(DT_INVALID);
162   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
163   // unknown_shape() set, and it seems hard to remove this without backwards
164   // compatibility issues.
165   if (kIsPartial && proto.unknown_rank()) {
166     out->set_ndims_byte(kUnknownRank);
167     out->set_num_elements(-1);
168   } else {
169     out->set_ndims_byte(0);
170     out->set_num_elements(1);
171     Status s = OkStatus();
172     for (const auto& d : proto.dim()) {
173       s = out->AddDimWithStatus(d.size());
174       if (!s.ok()) {
175         return s;
176       }
177     }
178   }
179   return OkStatus();
180 }
181 
182 template <class Shape>
TensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes)183 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes) {
184   set_tag(REP16);
185   set_data_type(DT_INVALID);
186   TF_CHECK_OK(InitDims(dim_sizes));
187 }
188 
189 template <class Shape>
BuildTensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes,TensorShapeBase * out)190 Status TensorShapeBase<Shape>::BuildTensorShapeBase(
191     gtl::ArraySlice<int64_t> dim_sizes, TensorShapeBase* out) {
192   out->set_tag(REP16);
193   out->set_data_type(DT_INVALID);
194   return out->InitDims(dim_sizes);
195 }
196 
197 // Returns true iff partial is true and val is < 0.
198 // REQUIRES: val < kMaxRep16
199 // REQUIRES: partial || val >= 0
Set16(bool partial,uint16 * dst,int dim,int64_t val)200 static inline bool Set16(bool partial, uint16* dst, int dim, int64_t val) {
201   if (partial) {
202     if (val < 0) {
203       dst[dim] = std::numeric_limits<uint16>::max();
204       return true;
205     }
206   }
207   dst[dim] = val;
208   return false;
209 }
210 
211 template <class Shape>
InitDims(gtl::ArraySlice<int64_t> dim_sizes)212 Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64_t> dim_sizes) {
213   DCHECK_EQ(tag(), REP16);
214 
215   // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
216   // below cannot overflow.
217   static const int64_t kMaxSmall = 0xd744;
218   static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
219                 "bad overflow check");
220   bool large_size = false;
221   for (auto s : dim_sizes) {
222     if (s > kMaxSmall) {
223       large_size = true;
224       break;
225     }
226   }
227 
228   if (!kIsPartial && !large_size) {
229     for (auto s : dim_sizes) {
230       if (TF_PREDICT_FALSE(s < 0)) {
231         return errors::InvalidArgument(
232             "Expected shape dimensions to be non-negative, got ", s);
233       }
234     }
235   }
236 
237   if (!large_size) {
238     // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
239     uint16* dst = as16()->dims_;
240     switch (dim_sizes.size()) {
241       case 1: {
242         set_ndims_byte(1);
243         const int64_t size = dim_sizes[0];
244         const bool neg = Set16(kIsPartial, dst, 0, size);
245         set_num_elements(neg ? -1 : size);
246         return OkStatus();
247       }
248       case 2: {
249         set_ndims_byte(2);
250         const int64_t size0 = dim_sizes[0];
251         const int64_t size1 = dim_sizes[1];
252         bool neg = Set16(kIsPartial, dst, 0, size0);
253         neg |= Set16(kIsPartial, dst, 1, size1);
254         set_num_elements(neg ? -1 : (size0 * size1));
255         return OkStatus();
256       }
257       case 3: {
258         set_ndims_byte(3);
259         const int64_t size0 = dim_sizes[0];
260         const int64_t size1 = dim_sizes[1];
261         const int64_t size2 = dim_sizes[2];
262         bool neg = Set16(kIsPartial, dst, 0, size0);
263         neg |= Set16(kIsPartial, dst, 1, size1);
264         neg |= Set16(kIsPartial, dst, 2, size2);
265         set_num_elements(neg ? -1 : (size0 * size1 * size2));
266         return OkStatus();
267       }
268       case 4: {
269         set_ndims_byte(4);
270         const int64_t size0 = dim_sizes[0];
271         const int64_t size1 = dim_sizes[1];
272         const int64_t size2 = dim_sizes[2];
273         const int64_t size3 = dim_sizes[3];
274         bool neg = Set16(kIsPartial, dst, 0, size0);
275         neg |= Set16(kIsPartial, dst, 1, size1);
276         neg |= Set16(kIsPartial, dst, 2, size2);
277         neg |= Set16(kIsPartial, dst, 3, size3);
278         set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
279         return OkStatus();
280       }
281     }
282   }
283 
284   set_ndims_byte(0);
285   set_num_elements(1);
286   Status status = OkStatus();
287   for (int64_t s : dim_sizes) {
288     status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
289     if (!status.ok()) {
290       return status;
291     }
292   }
293 
294   return status;
295 }
296 
297 template <class Shape>
TensorShapeBase()298 TensorShapeBase<Shape>::TensorShapeBase() {
299   set_tag(REP16);
300   set_data_type(DT_INVALID);
301   if (kIsPartial) {
302     set_ndims_byte(kUnknownRank);
303     set_num_elements(-1);
304   } else {
305     set_ndims_byte(0);
306     set_num_elements(1);
307   }
308 }
309 
DestructorOutOfLine()310 void TensorShapeRep::DestructorOutOfLine() {
311   DCHECK(tag() == REP_OUT_OF_LINE);
312   delete as64()->dims_;
313 }
314 
SlowCopyFrom(const TensorShapeRep & b)315 void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
316   if (b.tag() != REP_OUT_OF_LINE) {
317     if (tag() == REP_OUT_OF_LINE) {
318       delete as64()->dims_;
319     }
320     memcpy(buf(), b.buf(), sizeof(u_.buf));
321     // memcpy above implicitly also does:
322     //   set_tag(b.tag());
323     //   set_ndims_byte(b.ndims_byte());
324     //   set_data_type(b.data_type());
325   } else {
326     set_ndims_byte(b.ndims_byte());
327     set_data_type(b.data_type());
328     if (tag() == REP_OUT_OF_LINE) {
329       // vector already allocated
330       *(as64()->dims_) = *(b.as64()->dims_);
331     } else {
332       set_tag(REP_OUT_OF_LINE);
333       as64()->dims_ = new gtl::InlinedVector<int64_t, 4>(*(b.as64()->dims_));
334     }
335   }
336 }
337 
338 template <class Shape>
dim_size(int d) const339 int64_t TensorShapeBase<Shape>::dim_size(int d) const {
340   if (unknown_rank()) return -1;
341   DCHECK_GE(d, 0);
342   DCHECK_LT(d, dims());
343   if (tag() == REP16) {
344     uint16 dim = as16()->dims_[d];
345     if (kIsPartial && dim == kUnknownRep16) return -1;
346     return dim;
347   } else if (tag() == REP32) {
348     uint32 dim = as32()->dims_[d];
349     if (kIsPartial && dim == kUnknownRep32) return -1;
350     return dim;
351   } else {
352     return (*as64()->dims_)[d];
353   }
354 }
355 
Clear()356 void TensorShapeRep::Clear() {
357   ClearAllButDataType();
358   set_data_type(DT_INVALID);
359 }
360 
ClearAllButDataType()361 void TensorShapeRep::ClearAllButDataType() {
362   if (tag() == REP_OUT_OF_LINE) {
363     delete as64()->dims_;
364   }
365   set_tag(REP16);
366   set_ndims_byte(0);
367   // Leaves data_type alone
368   set_num_elements(1);
369 }
370 
371 template <class Shape>
RecomputeNumElements()372 Status TensorShapeBase<Shape>::RecomputeNumElements() {
373   if (unknown_rank()) {
374     set_num_elements(-1);
375     return OkStatus();
376   }
377   int64_t n = 1;
378   for (auto dim : *this) {
379     if (kIsPartial && dim.size < 0) {
380       n = -1;
381       break;
382     }
383     n = MultiplyWithoutOverflow(n, dim.size);
384     if (TF_PREDICT_FALSE(n < 0)) {
385       return errors::InvalidArgument(
386           "Shape ", this->DebugString(),
387           " results in overflow when computing number of elements");
388     }
389   }
390   set_num_elements(n);
391   return OkStatus();
392 }
393 
394 template <class Shape>
AddDim(int64_t size)395 void TensorShapeBase<Shape>::AddDim(int64_t size) {
396   if (!kIsPartial) CHECK_GE(size, 0);
397   if (unknown_rank()) return;
398   CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
399   int64_t new_num_elements;
400   if (kIsPartial && (num_elements() < 0 || size < 0)) {
401     new_num_elements = -1;
402   } else {
403     new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
404     CHECK_LE(0, new_num_elements);
405   }
406   UnsafeAddDim(size, new_num_elements);
407 }
408 
409 template <class Shape>
AddDimWithStatus(int64_t size)410 Status TensorShapeBase<Shape>::AddDimWithStatus(int64_t size) {
411   if (!kIsPartial) {
412     if (TF_PREDICT_FALSE(size < 0)) {
413       return errors::InvalidArgument("Expected a non-negative size, got ",
414                                      size);
415     }
416   }
417 
418   if (unknown_rank()) {
419     return OkStatus();
420   }
421 
422   if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
423     return errors::InvalidArgument("Too many dimensions in tensor");
424   }
425 
426   int64_t new_num_elements;
427   if (kIsPartial && (num_elements() < 0 || size < 0)) {
428     new_num_elements = -1;
429   } else {
430     new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
431     if (TF_PREDICT_FALSE(new_num_elements < 0)) {
432       return errors::InvalidArgument("Encountered overflow when multiplying ",
433                                      num_elements(), " with ", size,
434                                      ", result: ", new_num_elements);
435     }
436   }
437 
438   UnsafeAddDim(size, new_num_elements);
439   return OkStatus();
440 }
441 
442 template <class Shape>
UnsafeAddDim(int64_t size,int64_t new_num_elements)443 void TensorShapeBase<Shape>::UnsafeAddDim(int64_t size,
444                                           int64_t new_num_elements) {
445   const int nd = ndims_byte();
446   if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
447     as16()->dims_[nd] =
448         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
449   } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
450     as32()->dims_[nd] =
451         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
452   } else if (tag() == REP_OUT_OF_LINE) {
453     as64()->dims_->push_back(size);
454   } else {
455     // Need to change representation
456     gtl::InlinedVector<int64_t, 8> vals;
457     AppendTo(*this, &vals);
458     vals.push_back(size);
459     // We know we can't be REP16.  See if we have a small enough
460     // number of dimensions and each dimension's size is small enough
461     // to allow REP32.
462     bool can_be_rep32 = (vals.size() <= 3);
463     if (can_be_rep32) {
464       for (size_t i = 0; i < vals.size(); i++) {
465         if (vals[i] >= kMaxRep32) {
466           can_be_rep32 = false;
467           break;
468         }
469       }
470     }
471     if (can_be_rep32) {
472       set_tag(REP32);
473       for (size_t d = 0; d < vals.size(); d++) {
474         as32()->dims_[d] = kIsPartial && vals[d] < 0
475                                ? kUnknownRep32
476                                : static_cast<uint32>(vals[d]);
477       }
478     } else {
479       set_tag(REP_OUT_OF_LINE);
480       as64()->dims_ =
481           new gtl::InlinedVector<int64_t, 4>(vals.begin(), vals.end());
482     }
483   }
484   set_ndims_byte(nd + 1);
485   set_num_elements(new_num_elements);
486 }
487 
488 template <class Shape>
AppendShape(const TensorShapeBase & shape)489 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
490   for (auto d : shape) AddDim(d.size);
491 }
492 
493 template <class Shape>
AppendShapeWithStatus(const TensorShapeBase & shape)494 Status TensorShapeBase<Shape>::AppendShapeWithStatus(
495     const TensorShapeBase& shape) {
496   Status s = OkStatus();
497   for (auto d : shape) {
498     s.Update(AddDimWithStatus(d.size));
499     if (!s.ok()) {
500       return s;
501     }
502   }
503   return s;
504 }
505 
506 template <class Shape>
InsertDim(int d,int64_t size)507 void TensorShapeBase<Shape>::InsertDim(int d, int64_t size) {
508   CHECK_GE(d, 0);
509   CHECK_LE(d, dims());
510   if (!kIsPartial) CHECK_GE(size, 0);
511   CHECK_LT(dims(), MaxDimensions());
512   gtl::InlinedVector<int64_t, 8> vals;
513   AppendTo(*this, &vals);
514   vals.insert(vals.begin() + d, size);
515   ClearAllButDataType();
516   for (auto dval : vals) {
517     AddDim(dval);
518   }
519 }
520 
521 template <class Shape>
InsertDimWithStatus(int d,int64_t size)522 Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64_t size) {
523   if (!kIsPartial) {
524     if (TF_PREDICT_FALSE(size < 0)) {
525       return errors::InvalidArgument("Expected a non-negative size, got ",
526                                      size);
527     }
528   }
529 
530   if (TF_PREDICT_FALSE(d < 0)) {
531     return errors::Internal("The insertion index must be non-negative, got ",
532                             d);
533   }
534   if (TF_PREDICT_FALSE(d > dims())) {
535     return errors::Internal("The insertion index must be at most ", dims(),
536                             " got ", d);
537   }
538   if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
539     return errors::Internal("Shape has ", dims(),
540                             " dimensions which is the maximum allowed");
541   }
542 
543   gtl::InlinedVector<int64_t, 8> vals;
544   AppendTo(*this, &vals);
545   vals.insert(vals.begin() + d, size);
546   ClearAllButDataType();
547 
548   Status s = OkStatus();
549   for (auto dval : vals) {
550     s.Update(AddDimWithStatus(dval));
551     if (!s.ok()) {
552       return s;
553     }
554   }
555   return s;
556 }
557 
558 template <class Shape>
dim_sizes() const559 gtl::InlinedVector<int64_t, 4> TensorShapeBase<Shape>::dim_sizes() const {
560   gtl::InlinedVector<int64_t, 4> result;
561   for (auto dim : *this) {
562     result.push_back(dim.size);
563   }
564   return result;
565 }
566 
567 template <class Shape>
set_dim(int d,int64_t size)568 void TensorShapeBase<Shape>::set_dim(int d, int64_t size) {
569   CHECK_GE(d, 0);
570   CHECK_LT(d, dims());
571   if (!kIsPartial) {
572     CHECK_GE(size, 0);
573   }
574   if (tag() == REP16 && size < kMaxRep16) {
575     as16()->dims_[d] =
576         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
577   } else if (tag() == REP32 && size < kMaxRep32) {
578     as32()->dims_[d] =
579         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
580   } else if (tag() == REP_OUT_OF_LINE) {
581     (*as64()->dims_)[d] = size;
582   } else {
583     // Must upgrade
584     gtl::InlinedVector<int64_t, 8> vals;
585     AppendTo(*this, &vals);
586     vals[d] = size;
587     ClearAllButDataType();
588     for (auto dval : vals) {
589       AddDim(dval);
590     }
591   }
592   TF_CHECK_OK(RecomputeNumElements());
593 }
594 
595 template <class Shape>
SetDimWithStatus(int d,int64_t size)596 Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64_t size) {
597   if (TF_PREDICT_FALSE(d < 0)) {
598     return errors::InvalidArgument("Index must be non-negative, got ", d);
599   }
600   if (TF_PREDICT_FALSE(d >= dims())) {
601     return errors::InvalidArgument("Index must be less than ", dims(), ", got ",
602                                    d);
603   }
604   if (TF_PREDICT_FALSE(!kIsPartial && size < 0)) {
605     return errors::InvalidArgument("Expected a non-negative size, got ", size);
606   }
607 
608   if (tag() == REP16 && size < kMaxRep16) {
609     as16()->dims_[d] =
610         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
611   } else if (tag() == REP32 && size < kMaxRep32) {
612     as32()->dims_[d] =
613         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
614   } else if (tag() == REP_OUT_OF_LINE) {
615     (*as64()->dims_)[d] = size;
616   } else {
617     // Must upgrade
618     gtl::InlinedVector<int64_t, 8> vals;
619     AppendTo(*this, &vals);
620     vals[d] = size;
621     ClearAllButDataType();
622 
623     Status s = OkStatus();
624     for (auto dval : vals) {
625       s.Update(AddDimWithStatus(dval));
626       if (!s.ok()) {
627         return s;
628       }
629     }
630   }
631 
632   return RecomputeNumElements();
633 }
634 
635 template <class Shape>
RemoveDimRange(int begin,int end)636 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
637   if (unknown_rank()) return;
638   begin = begin < 0 ? dims() + begin + 1 : begin;
639   end = end < 0 ? dims() + end + 1 : end;
640   CHECK_GE(begin, 0);
641   CHECK_LE(begin, dims());
642   CHECK_GE(end, 0);
643   CHECK_LE(end, dims());
644   if (begin >= end) return;
645   gtl::InlinedVector<int64_t, 8> vals;
646   AppendTo(*this, &vals);
647   vals.erase(vals.begin() + begin, vals.begin() + end);
648   ClearAllButDataType();
649   for (auto dval : vals) {
650     AddDim(dval);
651   }
652   TF_CHECK_OK(RecomputeNumElements());
653 }
654 
655 template <class Shape>
RemoveDimRangeWithStatus(int begin,int end)656 Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
657   if (unknown_rank()) {
658     return OkStatus();
659   }
660 
661   begin = begin < 0 ? dims() + begin + 1 : begin;
662   end = end < 0 ? dims() + end + 1 : end;
663 
664   if (TF_PREDICT_FALSE(begin < 0)) {
665     return errors::Internal("Start index must be non-negative, got ", begin);
666   }
667   if (TF_PREDICT_FALSE(begin > dims())) {
668     return errors::Internal("Start index must be less than ", dims(), ", got ",
669                             begin);
670   }
671   if (TF_PREDICT_FALSE(end < 0)) {
672     return errors::Internal("End index must be non-negative, got ", end);
673   }
674   if (TF_PREDICT_FALSE(end > dims())) {
675     return errors::Internal("End index must be less than ", dims(), ", got ",
676                             end);
677   }
678 
679   if (begin >= end) {
680     return OkStatus();
681   }
682 
683   gtl::InlinedVector<int64_t, 8> vals;
684   AppendTo(*this, &vals);
685   vals.erase(vals.begin() + begin, vals.begin() + end);
686   ClearAllButDataType();
687 
688   Status s = OkStatus();
689   for (auto dval : vals) {
690     s.Update(AddDimWithStatus(dval));
691     if (!s.ok()) {
692       return s;
693     }
694   }
695 
696   return RecomputeNumElements();
697 }
698 
IsSameSize(const TensorShape & b) const699 bool TensorShape::IsSameSize(const TensorShape& b) const {
700   if (b.dims() != dims()) return false;
701   for (int d = 0; d < dims(); d++) {
702     if (dim_size(d) != b.dim_size(d)) return false;
703   }
704   return true;
705 }
706 
707 template <class Shape>
AsProto(TensorShapeProto * proto) const708 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
709   proto->Clear();
710   if (unknown_rank()) {
711     proto->set_unknown_rank(true);
712   } else {
713     for (int i = 0; i < dims(); i++) {
714       proto->add_dim()->set_size(dim_size(i));
715     }
716   }
717 }
718 
719 template <class Shape>
AsProto() const720 TensorShapeProto TensorShapeBase<Shape>::AsProto() const {
721   TensorShapeProto out;
722   AsProto(&out);
723   return out;
724 }
725 
726 template <class Shape>
begin() const727 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
728   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
729 }
730 
731 template <class Shape>
end() const732 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
733   const int max_dim = unknown_rank() ? -1 : dims();
734   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), max_dim);
735 }
736 
DebugString() const737 string TensorShapeRep::DebugString() const {
738   const auto& shape = *static_cast<const PartialTensorShape*>(this);
739   if (shape.unknown_rank()) return "<unknown>";
740   string s = "[";
741   for (int i = 0; i < shape.dims(); i++) {
742     if (i > 0) strings::StrAppend(&s, ",");
743     int64_t dim = shape.dim_size(i);
744     if (dim < 0) {
745       strings::StrAppend(&s, "?");
746     } else {
747       strings::StrAppend(&s, dim);
748     }
749   }
750   strings::StrAppend(&s, "]");
751   return s;
752 }
753 
DebugString(const TensorShapeProto & proto)754 string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
755   string s;
756   if (proto.unknown_rank()) {
757     strings::StrAppend(&s, "<unknown>");
758     if (proto.dim_size() == 0) return s;
759   }
760   strings::StrAppend(&s, "[");
761   bool first = true;
762   for (const auto& d : proto.dim()) {
763     if (!first) strings::StrAppend(&s, ",");
764     if (d.size() == -1) {
765       strings::StrAppend(&s, "?");
766     } else {
767       strings::StrAppend(&s, d.size());
768     }
769     first = false;
770   }
771   strings::StrAppend(&s, "]");
772   return s;
773 }
774 
StartsWith(const TensorShape & shape,const TensorShape & prefix)775 bool TensorShapeUtils::StartsWith(const TensorShape& shape,
776                                   const TensorShape& prefix) {
777   if (shape.dims() < prefix.dims()) return false;
778   for (int i = 0; i < prefix.dims(); ++i) {
779     if (shape.dim_size(i) != prefix.dim_size(i)) return false;
780   }
781   return true;
782 }
783 
EndsWith(const TensorShape & shape,const TensorShape & suffix)784 bool TensorShapeUtils::EndsWith(const TensorShape& shape,
785                                 const TensorShape& suffix) {
786   const int suffix_size = suffix.dims();
787   if (shape.dims() < suffix_size) return false;
788   for (int i = 0; i < suffix_size; ++i) {
789     if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
790       return false;
791     }
792   }
793   return true;
794 }
795 
796 template <typename T, class Shape>
MakeShapeHelper(const T * dims,int64_t n,Shape * out)797 Status MakeShapeHelper(const T* dims, int64_t n, Shape* out) {
798   out->Clear();
799   if (n > TensorShape::MaxDimensions()) {
800     return errors::InvalidArgument("Too many dimensions");
801   }
802   if (n < 0) {
803     return errors::InvalidArgument("Negative number of dimensions ", n);
804   }
805   for (int64_t i = 0; i < n; ++i) {
806     T dim = internal::SubtleMustCopy(dims[i]);
807     int64_t new_num_elements;
808     if (dim < 0) {
809       if (!out->kIsPartial) {
810         return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
811       }
812       if (dim < -1) {
813         return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
814       }
815       dim = -1;
816       new_num_elements = -1;
817     } else if (out->num_elements() < 0) {
818       new_num_elements = -1;
819     } else {
820       new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
821       if (TF_PREDICT_FALSE(new_num_elements < 0)) {
822         TensorShapeProto proto;
823         for (int64_t j = 0; j < n; ++j) {
824           proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
825         }
826         return errors::InvalidArgument(
827             "Shape ", TensorShape::DebugString(proto),
828             " would have more than 2**63 - 1 elements");
829       }
830     }
831     out->UnsafeAddDim(dim, new_num_elements);
832   }
833   return OkStatus();
834 }
835 
836 #define MAKE_SHAPE(T, Shape)                                                 \
837   Status TensorShapeUtils::MakeShape(const T* dims, int64_t n, Shape* out) { \
838     return MakeShapeHelper(dims, n, out);                                    \
839   }                                                                          \
840   Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
841     return MakeShapeHelper(shape.data(), shape.size(), out);                 \
842   }
MAKE_SHAPE(int32,TensorShape)843 MAKE_SHAPE(int32, TensorShape)
844 MAKE_SHAPE(int64_t, TensorShape)
845 MAKE_SHAPE(int32, PartialTensorShape)
846 MAKE_SHAPE(int64_t, PartialTensorShape)
847 #undef MAKE_SHAPE
848 
849 string TensorShapeUtils::ShapeListString(
850     const gtl::ArraySlice<TensorShape>& shapes) {
851   string result = "[";
852   bool first = true;
853   for (const TensorShape& shape : shapes) {
854     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
855     first = false;
856   }
857   strings::StrAppend(&result, "]");
858   return result;
859 }
860 
Concatenate(int64_t size) const861 PartialTensorShape PartialTensorShape::Concatenate(int64_t size) const {
862   PartialTensorShape out = *this;
863   out.AddDim(size);
864   return out;
865 }
866 
ConcatenateWithStatus(int64_t size,PartialTensorShape * out) const867 Status PartialTensorShape::ConcatenateWithStatus(
868     int64_t size, PartialTensorShape* out) const {
869   out = const_cast<PartialTensorShape*>(this);
870   return out->AddDimWithStatus(size);
871 }
872 
// Returns this shape followed by all dimensions of `shape`. The result has
// unknown rank if either operand does. Errors are fatal (AddDim semantics).
PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  const bool either_unknown = unknown_rank() || shape.unknown_rank();
  if (either_unknown) return PartialTensorShape();
  PartialTensorShape joined(*this);
  joined.AppendShape(shape);
  return joined;
}
882 
ConcatenateWithStatus(const PartialTensorShape & shape,PartialTensorShape * out) const883 Status PartialTensorShape::ConcatenateWithStatus(
884     const PartialTensorShape& shape, PartialTensorShape* out) const {
885   if (unknown_rank() || shape.unknown_rank()) {
886     *out = PartialTensorShape();
887     return OkStatus();
888   }
889   out = const_cast<PartialTensorShape*>(this);
890   for (auto dim : shape) {
891     Status s = out->AddDimWithStatus(dim.size);
892     if (!s.ok()) return s;
893   }
894 
895   return OkStatus();
896 }
897 
MergeWith(const PartialTensorShape & shape,PartialTensorShape * result) const898 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
899                                      PartialTensorShape* result) const {
900   if (unknown_rank()) {
901     *result = shape;
902     return OkStatus();
903   }
904   if (shape.unknown_rank()) {
905     *result = *this;
906     return OkStatus();
907   }
908   const int dims_ = dims();
909   if (dims_ != shape.dims()) {
910     return errors::InvalidArgument(
911         "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
912         shape.dims());
913   }
914 
915   if (result == this) {
916     return errors::Internal(
917         "PartialTensorShape::MergeWith: cannot merge shape with itself");
918   }
919 
920   result->Clear();
921   Status s = OkStatus();
922   for (int i = 0; i < dims_; ++i) {
923     const int64_t dim0 = dim_size(i);
924     const int64_t dim1 = shape.dim_size(i);
925     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
926       return errors::InvalidArgument(
927           "PartialTensorShape: Incompatible shapes during merge: ",
928           DebugString(), " vs. ", shape.DebugString());
929     }
930     s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
931     if (!s.ok()) {
932       return s;
933     }
934   }
935   return OkStatus();
936 }
937 
AsTensorShape(TensorShape * shape) const938 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
939   if (IsFullyDefined()) {
940     const TensorShapeRep* rep = this;
941     *shape = *static_cast<const TensorShape*>(rep);
942     return true;
943   }
944   return false;
945 }
946 
IsIdenticalTo(const PartialTensorShape & shape) const947 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
948   if (unknown_rank() || shape.unknown_rank()) {
949     return unknown_rank() == shape.unknown_rank();
950   }
951   if (dims() != shape.dims()) return false;
952   for (int i = 0; i < dims(); i++) {
953     if (dim_size(i) != shape.dim_size(i)) return false;
954   }
955   return true;
956 }
957 
IsCompatibleWith(const PartialTensorShape & shape) const958 bool PartialTensorShape::IsCompatibleWith(
959     const PartialTensorShape& shape) const {
960   if (unknown_rank() || shape.unknown_rank()) return true;
961   if (dims() != shape.dims()) return false;
962   for (int i = 0; i < dims(); i++) {
963     const int64_t dim0 = dim_size(i);
964     const int64_t dim1 = shape.dim_size(i);
965     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
966   }
967   return true;
968 }
969 
PartialShapeListString(const gtl::ArraySlice<PartialTensorShape> & shapes)970 string PartialTensorShapeUtils::PartialShapeListString(
971     const gtl::ArraySlice<PartialTensorShape>& shapes) {
972   string result = "[";
973   bool first = true;
974   for (const PartialTensorShape& shape : shapes) {
975     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
976     first = false;
977   }
978   strings::StrAppend(&result, "]");
979   return result;
980 }
981 
AreCompatible(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)982 bool PartialTensorShapeUtils::AreCompatible(
983     const gtl::ArraySlice<PartialTensorShape>& shapes0,
984     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
985   if (shapes0.size() == shapes1.size()) {
986     for (size_t i = 0; i < shapes0.size(); ++i) {
987       if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
988         return false;
989       }
990     }
991     return true;
992   } else {
993     return false;
994   }
995 }
996 
AreIdentical(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)997 bool PartialTensorShapeUtils::AreIdentical(
998     const gtl::ArraySlice<PartialTensorShape>& shapes0,
999     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
1000   if (shapes0.size() == shapes1.size()) {
1001     for (size_t i = 0; i < shapes0.size(); ++i) {
1002       if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
1003         return false;
1004       }
1005     }
1006     return true;
1007   } else {
1008     return false;
1009   }
1010 }
1011 
NumElements(gtl::ArraySlice<int64_t> shape,int64_t * num_elements)1012 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64_t> shape,
1013                                      int64_t* num_elements) {
1014   int64_t n = 1;
1015   for (auto dim : shape) {
1016     n = MultiplyWithoutOverflow(n, dim);
1017     if (n < 0) {
1018       return errors::InvalidArgument("Can't compute total size of shape [",
1019                                      absl::StrJoin(shape, ","),
1020                                      "]; product would overflow int64");
1021     }
1022   }
1023   *num_elements = n;
1024   return OkStatus();
1025 }
1026 
1027 template class TensorShapeBase<TensorShape>;
1028 template class TensorShapeBase<PartialTensorShape>;
1029 
1030 }  // namespace tensorflow
1031