/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation notes:
//
// Tensor.cc uses a few templated classes and structs to facilitate
// implementation of the Tensor class.
//
// * Buffer<T>: provides the implementation for a typed array T[n].
//   The array is allocated by the given allocator. It runs T's
//   default constructors and destructors when T is not a simple type
//   (e.g., string), and skips them otherwise.
//
// * Helper<T>: provides various routines given type T.  The routines
//   include running the constructors and destructors of T[], and
//   encoding and decoding T[] into/from a Cord, etc.
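//
// For example (an illustrative sketch only, not part of the public API):
//
//   Buffer<float>* buf = new Buffer<float>(cpu_allocator(), 16);
//
// allocates a ref-counted array of 16 floats and skips per-element
// construction, while Helper<float>::TotalBytes(buf, 16) would report
// 16 * sizeof(float), since simple types carry no per-element overhead.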

#include "tensorflow/core/framework/tensor.h"

#include <utility>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_handle.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/typed_allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Allow Tensors to be stored inside Variants with automatic
// encoding/decoding when those Variants are themselves being decoded
// in a Tensor's FromProto.
//
// NOTE(mrry): The corresponding "copy function" registrations can be found in
// ../common_runtime/copy_tensor.cc (due to dependencies on other common_runtime
// code).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(Tensor, "tensorflow::Tensor");
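
// As an illustrative sketch, a user-defined type stored inside a Variant
// would register its decoder the same way (the type and name below are
// hypothetical):
//
//   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyPayload, "MyPayload");
//
// This registration is what lets DecodeUnaryVariant() (used by
// FromProtoField<Variant> below) reconstruct a payload from its type name.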

bool TensorBuffer::GetAllocatedBytes(size_t* out_bytes) const {
  AllocationDescription allocation_description;
  FillAllocationDescription(&allocation_description);
  if (allocation_description.allocated_bytes() > 0) {
    *out_bytes = allocation_description.allocated_bytes();
    return true;
  } else {
    return false;
  }
}

namespace {

// An un-templated base class for Buffer.
class BufferBase : public TensorBuffer {
 public:
  explicit BufferBase(Allocator* alloc, void* data_ptr)
      : TensorBuffer(data_ptr), alloc_(alloc) {}

  TensorBuffer* root_buffer() override { return this; }

  bool GetAllocatedBytes(size_t* out_bytes) const override {
    if (alloc_->TracksAllocationSizes()) {
      *out_bytes = alloc_->AllocatedSize(data());
      return *out_bytes > 0;
    } else {
      return false;
    }
  }

  void FillAllocationDescription(AllocationDescription* proto) const override {
    void* data_ptr = data();
    int64_t rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(alloc_->Name());
    proto->set_ptr(reinterpret_cast<uintptr_t>(data_ptr));
    if (alloc_->TracksAllocationSizes()) {
      int64_t ab = alloc_->AllocatedSize(data_ptr);
      proto->set_allocated_bytes(ab);
      int64_t id = alloc_->AllocationId(data_ptr);
      if (id > 0) {
        proto->set_allocation_id(id);
      }
      if (RefCountIsOne()) {
        proto->set_has_single_reference(true);
      }
    }
  }

  // Returns the type of the underlying memory.
  AllocatorMemoryType GetMemoryType() const override {
    return alloc_->GetMemoryType();
  }

 protected:
  void RecordDeallocation() {
    LogMemory::RecordTensorDeallocation(alloc_->AllocationId(data()),
                                        alloc_->Name());
  }

  Allocator* const alloc_;
};

// Typed ref-counted buffer: T[n].
template <typename T>
class Buffer : public BufferBase {
 public:
  Buffer(Allocator* a, int64_t n);
  Buffer(Allocator* a, int64_t n, const AllocationAttributes& allocation_attr);

  size_t size() const override { return sizeof(T) * elem_; }

 private:
  int64_t elem_;

  ~Buffer() override;

  TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
};

void LogUnexpectedSize(int64_t actual, int64_t expected) {
  LOG(ERROR) << "Input size was " << actual << " and expected " << expected;
}

bool MemoryLoggingEnabled() {
  static bool memory_logging_enabled = LogMemory::IsEnabled();
  return memory_logging_enabled;
}

// A set of helper functions depending on T.
template <typename T>
struct Helper {
  // By default, we assume T is a simple type (float, int32, etc.)
  static_assert(is_simple_type<T>::value, "T is not a simple type.");
  typedef protobuf::RepeatedField<T> RepeatedFieldType;

  // Encoder of simple type T to a string.  We do a copy.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
                           out);
  }

  // Decoder of simple type T. Copy the bytes from "in" into the
  // tensor buffer.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    if (in.size() != sizeof(T) * n) {
      LogUnexpectedSize(in.size(), sizeof(T) * n);
      return nullptr;
    }
    Buffer<T>* buf = new Buffer<T>(a, n);
    char* data = buf->template base<char>();
    if (data == nullptr) {
      buf->Unref();
      return nullptr;
    }
    port::CopyToArray(in, data);
    return buf;
  }

  // Memory usage.
  static int64_t TotalBytes(TensorBuffer* in, int64_t n) {
    DCHECK_EQ(in->size(), sizeof(T) * n);
    return in->size();
  }
};

// Helper specialization for string (the only non-simple type we
// support).
template <>
struct Helper<tstring> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type string stored in "in" into Cord
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    port::EncodeStringList(in->base<const tstring>(), n, out);
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    Buffer<tstring>* buf = new Buffer<tstring>(a, n);
    tstring* strings = buf->template base<tstring>();
    if (strings == nullptr || !port::DecodeStringList(in, strings, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64_t TotalBytes(TensorBuffer* in, int n) {
    int64_t tot = in->size();
    DCHECK_EQ(tot, sizeof(tstring) * n);
    const tstring* p = in->base<const tstring>();
    for (int i = 0; i < n; ++i, ++p) tot += p->size();
    return tot;
  }
};

template <>
struct Helper<ResourceHandle> {
  // Proto message uses RepeatedFieldType to hold repeated T.
  typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;

  // Encodes "n" elements of type ResourceHandle stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    EncodeResourceHandleList(in->base<const ResourceHandle>(), n,
                             port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type string from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    auto* buf = new Buffer<ResourceHandle>(a, n);
    ResourceHandle* ps = buf->template base<ResourceHandle>();
    if (ps == nullptr ||
        !DecodeResourceHandleList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64_t TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(ResourceHandle);
  }
};

template <>
struct Helper<Variant> {
  // Encodes "n" elements of type Variant stored in "in" into destination
  // "out", which is usually the TensorProto::tensor_content.
  template <typename Destination>
  static void Encode(TensorBuffer* in, int64_t n, Destination* out) {
    EncodeVariantList(in->base<const Variant>(), n,
                      port::NewStringListEncoder(out));
  }

  // Decodes "n" elements of type Variant from "in" and constructs a
  // buffer out of it. Returns nullptr if the decoding fails. "in" is
  // usually the TensorProto::tensor_content.
  template <typename Source>
  static TensorBuffer* Decode(Allocator* a, const Source& in, int64_t n) {
    auto* buf = new Buffer<Variant>(a, n);
    Variant* ps = buf->template base<Variant>();
    if (ps == nullptr ||
        !DecodeVariantList(port::NewStringListDecoder(in), ps, n)) {
      buf->Unref();
      return nullptr;
    }
    return buf;
  }

  // Returns the estimated memory usage of "n" elements of type T
  // stored in buffer "in".
  static int64_t TotalBytes(TensorBuffer* in, int n) {
    return n * sizeof(Variant);
  }
};

template <typename T>
struct ProtoHelper {};

// For a C++ type "T" (float, double, int32, etc.), the repeated field
// "N"_val (float_val, int_val, string_val, etc.) of type "F" (float,
// int32, string, etc.) in the TensorProto is used for serializing the
// tensor of type "T".
#define PROTO_TRAITS(T, F, N)                                          \
  template <>                                                          \
  struct ProtoHelper<T> {                                              \
    typedef Helper<F>::RepeatedFieldType FieldType;                    \
    static FieldType::const_iterator Begin(const TensorProto& proto) { \
      return proto.N##_val().begin();                                  \
    }                                                                  \
    static size_t NumElements(const TensorProto& proto) {              \
      return proto.N##_val().size();                                   \
    }                                                                  \
    static void Fill(const T* data, size_t n, TensorProto* proto) {    \
      typename ProtoHelper<T>::FieldType copy(data, data + n);         \
      proto->mutable_##N##_val()->Swap(&copy);                         \
    }                                                                  \
  };
PROTO_TRAITS(float, float, float);
PROTO_TRAITS(double, double, double);
PROTO_TRAITS(int32, int32, int);
PROTO_TRAITS(uint8, int32, int);
PROTO_TRAITS(uint16, int32, int);
PROTO_TRAITS(uint32, uint32, uint32);
PROTO_TRAITS(int16, int32, int);
PROTO_TRAITS(int8, int32, int);
PROTO_TRAITS(bool, bool, bool);
PROTO_TRAITS(tstring, tstring, string);
PROTO_TRAITS(qint8, int32, int);
PROTO_TRAITS(quint8, int32, int);
PROTO_TRAITS(qint16, int32, int);
PROTO_TRAITS(quint16, int32, int);
#undef PROTO_TRAITS
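
// As an illustrative sketch, PROTO_TRAITS(uint8, int32, int) above expands to
// (roughly):
//
//   template <>
//   struct ProtoHelper<uint8> {
//     typedef Helper<int32>::RepeatedFieldType FieldType;
//     static FieldType::const_iterator Begin(const TensorProto& proto) {
//       return proto.int_val().begin();
//     }
//     static size_t NumElements(const TensorProto& proto) {
//       return proto.int_val().size();
//     }
//     static void Fill(const uint8* data, size_t n, TensorProto* proto) {
//       typename ProtoHelper<uint8>::FieldType copy(data, data + n);
//       proto->mutable_int_val()->Swap(&copy);
//     }
//   };
//
// i.e. narrow integer types are widened and serialized through the
// TensorProto int_val (int32) repeated field.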

template <>
struct ProtoHelper<int64_t> {
  static protobuf::RepeatedField<int64_t>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.int64_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int64_val().size();
  }
  static void Fill(const int64_t* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_int64> copy(data, data + n);
    proto->mutable_int64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<uint64> {
  static protobuf::RepeatedField<uint64_t>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.uint64_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.uint64_val().size();
  }
  static void Fill(const uint64* data, size_t n, TensorProto* proto) {
    protobuf::RepeatedField<protobuf_uint64> copy(data, data + n);
    proto->mutable_uint64_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<ResourceHandle> {
  static protobuf::RepeatedPtrField<ResourceHandleProto>::const_iterator Begin(
      const TensorProto& proto) {
    return proto.resource_handle_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.resource_handle_val().size();
  }
  static void Fill(const ResourceHandle* data, size_t n, TensorProto* proto) {
    auto* handles = proto->mutable_resource_handle_val();
    handles->Clear();
    for (size_t i = 0; i < n; i++) {
      data[i].AsProto(handles->Add());
    }
  }
};

template <>
struct ProtoHelper<Variant> {
  static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
  Begin(const TensorProto& proto) {
    return proto.variant_val().begin();
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.variant_val().size();
  }
  static void Fill(const Variant* data, size_t n, TensorProto* proto) {
    auto* variant_values = proto->mutable_variant_val();
    variant_values->Clear();
    for (size_t i = 0; i < n; ++i) {
      VariantTensorData tmp;
      data[i].Encode(&tmp);
      tmp.ToProto(variant_values->Add());
    }
  }
};

template <>
struct ProtoHelper<complex64> {
  typedef Helper<float>::RepeatedFieldType FieldType;
  static const complex64* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.scomplex_val().size() / 2;
  }
  static void Fill(const complex64* data, size_t n, TensorProto* proto) {
    const float* p = reinterpret_cast<const float*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_scomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<complex128> {
  typedef Helper<double>::RepeatedFieldType FieldType;
  static const complex128* Begin(const TensorProto& proto) {
    return reinterpret_cast<const complex128*>(proto.dcomplex_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.dcomplex_val().size() / 2;
  }
  static void Fill(const complex128* data, size_t n, TensorProto* proto) {
    const double* p = reinterpret_cast<const double*>(data);
    FieldType copy(p, p + n * 2);
    proto->mutable_dcomplex_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<qint32> {
  typedef Helper<int32>::RepeatedFieldType FieldType;
  static const qint32* Begin(const TensorProto& proto) {
    return reinterpret_cast<const qint32*>(proto.int_val().data());
  }
  static size_t NumElements(const TensorProto& proto) {
    return proto.int_val().size();
  }
  static void Fill(const qint32* data, size_t n, TensorProto* proto) {
    const int32* p = reinterpret_cast<const int32*>(data);
    FieldType copy(p, p + n);
    proto->mutable_int_val()->Swap(&copy);
  }
};

template <>
struct ProtoHelper<bfloat16> {
  static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(
          Eigen::numext::bit_cast<uint16>(data[i]));
    }
  }
};

template <>
struct ProtoHelper<Eigen::half> {
  static void Fill(const Eigen::half* data, size_t n, TensorProto* proto) {
    proto->mutable_half_val()->Reserve(n);
    for (size_t i = 0; i < n; ++i) {
      proto->mutable_half_val()->AddAlreadyReserved(
          Eigen::numext::bit_cast<uint16>(data[i]));
    }
  }
};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
      elem_(n) {}

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n,
                  const AllocationAttributes& allocation_attr)
    : BufferBase(a, TypedAllocator::Allocate<T>(a, n, allocation_attr)),
      elem_(n) {}

template <typename T>
Buffer<T>::~Buffer() {
  if (data()) {
    if (MemoryLoggingEnabled()) {
      RecordDeallocation();
    }
    TypedAllocator::Deallocate<T>(alloc_, static_cast<T*>(data()), elem_);
  }
}

// Allocates a T[n] buffer. Fills in the buffer with repeated values
// in "in".  If "in" has fewer values than "n", fills the rest of T[n]
// with the last value. If "in" has no values, fills T[n] with the
// default value for T.
//
// This routine uses the typed fields (float_val, etc.) in the
// tensor proto as opposed to the untyped binary representation
// (tensor_content). This is used when we expect the TensorProto is
// used by a client program which may not know how to encode a tensor
// in the compact binary representation.
template <typename T>
TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<T>* buf = new Buffer<T>(a, n);
  T* data = buf->template base<T>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }

  const int64_t in_n = ProtoHelper<T>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, T());
  } else {
    auto begin = ProtoHelper<T>::Begin(in);
    if (n <= in_n) {
      std::copy_n(begin, n, data);
    } else {
      std::copy_n(begin, in_n, data);
      if (std::is_trivially_copyable<T>::value) {
        const T last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      } else {
        const T& last = *(data + in_n - 1);
        std::fill_n(data + in_n, n - in_n, last);
      }
    }
  }

  return buf;
}
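
// Illustrative sketch of the fill semantics above:
//
//   TensorProto proto;
//   proto.set_dtype(DT_FLOAT);
//   proto.add_float_val(1.0f);
//   proto.add_float_val(2.0f);
//   // FromProtoField<float>(cpu_allocator(), proto, 5) yields a buffer
//   // holding {1, 2, 2, 2, 2}: the last provided value is repeated, which
//   // keeps splat-encoded protos compact.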

// Separate implementation for `ResourceHandle` to handle the case when the
// proto for the resource is invalid. See `resource_handle.h` constructor and
// static factory builder.
template <>
TensorBuffer* FromProtoField<ResourceHandle>(Allocator* a,
                                             const TensorProto& in, int64_t n) {
  CHECK_GT(n, 0);
  Buffer<ResourceHandle>* buf = new Buffer<ResourceHandle>(a, n);
  ResourceHandle* data = buf->template base<ResourceHandle>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<ResourceHandle>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, ResourceHandle());
  } else {
    // If the tensor shape says we have n < in_n elements in the output tensor,
    // make sure to decode only the first n of the in_n elements in the input.
    // In all other cases, decode all in_n elements of the input and set the
    // remaining elements up to n to the default ResourceHandle() value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      Status s = ResourceHandle::BuildResourceHandle(in.resource_handle_val(i),
                                                     &data[i]);
      if (!s.ok()) {
        LOG(ERROR) << "Could not decode resource handle from proto \""
                   << in.resource_handle_val(i).ShortDebugString()
                   << "\", returned status: " << s.ToString();
        buf->Unref();
        return nullptr;
      }
    }
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = ResourceHandle();
    }
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
                                      int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Variant>* buf = new Buffer<Variant>(a, n);
  Variant* data = buf->template base<Variant>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = ProtoHelper<Variant>::NumElements(in);
  if (in_n <= 0) {
    std::fill_n(data, n, Variant());
  } else {
    // If the tensor shape says we have n < in_n elements in the output tensor,
    // make sure to decode only the first n of the in_n elements in the input.
    // In all other cases, decode all in_n elements of the input and set the
    // remaining elements up to n to the default Variant() value.
    const int64_t real_n = n < in_n ? n : in_n;
    for (int64_t i = 0; i < real_n; ++i) {
      data[i] = in.variant_val(i);
      if (!DecodeUnaryVariant(&data[i])) {
        LOG(ERROR) << "Could not decode variant with type_name: \""
                   << data[i].TypeName()
                   << "\".  Perhaps you forgot to register a "
                      "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
        buf->Unref();
        return nullptr;
      }
    }
    for (int64_t i = in_n; i < n; ++i) {
      data[i] = Variant();
    }
  }
  return buf;
}

// fp16 and bfloat16 are opaque to the protobuf, so we deserialize them
// identically to uint16 but with data stored in half_val instead of int_val
// (i.e., we don't use ProtoHelper<uint16>).
template <>
TensorBuffer* FromProtoField<Eigen::half>(Allocator* a, const TensorProto& in,
                                          int64_t n) {
  CHECK_GT(n, 0);
  Buffer<Eigen::half>* buf = new Buffer<Eigen::half>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

template <>
TensorBuffer* FromProtoField<bfloat16>(Allocator* a, const TensorProto& in,
                                       int64_t n) {
  CHECK_GT(n, 0);
  Buffer<bfloat16>* buf = new Buffer<bfloat16>(a, n);
  uint16* data = buf->template base<uint16>();
  if (data == nullptr) {
    buf->Unref();
    return nullptr;
  }
  const int64_t in_n = in.half_val().size();
  auto begin = in.half_val().begin();
  if (n <= in_n) {
    std::copy_n(begin, n, data);
  } else if (in_n > 0) {
    std::copy_n(begin, in_n, data);
    const uint16 last = *(data + in_n - 1);
    std::fill_n(data + in_n, n - in_n, last);
  } else {
    std::fill_n(data, n, 0);
  }
  return buf;
}

// Copies T[n] stored in the buffer "in" into the repeated field in
// "out" corresponding to type T.
template <typename T>
void ToProtoField(const TensorBuffer& in, int64_t n, TensorProto* out) {
  const T* data = in.base<const T>();
  // NOTE: T may not be the same as
  // ProtoHelper<T>::FieldType::value_type.  E.g., T==int16,
  // ProtoHelper<T>::FieldType::value_type==int32.  If performance is
  // critical, we can specialize T=float and do memcpy directly.
  ProtoHelper<T>::Fill(data, n, out);
}

void RefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Ref();
}

void UnrefIfNonNull(core::RefCounted* buf) {
  if (buf) buf->Unref();
}

}  // end namespace

Tensor::Tensor() : Tensor(DT_FLOAT) {}

Tensor::Tensor(DataType type) : shape_(type), buf_(nullptr) {}

Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
    : shape_(shape), buf_(buf) {
  set_dtype(type);
  RefIfNonNull(buf);
}

Tensor::Tensor(DataType type, TensorShape shape,
               core::RefCountPtr<TensorBuffer> buf)
    : shape_(std::move(shape)), buf_(buf.release()) {
  set_dtype(type);
}

bool Tensor::IsInitialized() const {
  return (buf_ != nullptr && buf_->data() != nullptr) ||
         shape_.num_elements() == 0;
}

void Tensor::CheckType(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
}

void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
  CHECK_EQ(dtype(), expected_dtype)
      << " " << DataTypeString(expected_dtype) << " expected, got "
      << DataTypeString(dtype());
  CHECK(IsAligned()) << "ptr = " << base<void>();
}

void Tensor::CheckIsAlignedAndSingleElement() const {
  CHECK(IsAligned()) << "Aligned and single element";
  CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
}

Tensor::~Tensor() { UnrefIfNonNull(buf_); }

Status Tensor::BitcastFrom(const Tensor& other, DataType dtype,
                           const TensorShape& shape) {
  int in_size = DataTypeSize(other.dtype());
  int out_size = DataTypeSize(dtype);
  if (in_size == 0) {
    return errors::InvalidArgument("other tensor has zero-sized data type");
  }
  if (out_size == 0) {
    return errors::InvalidArgument("specified output type is zero-sized");
  }
  if (shape.num_elements() * out_size !=
      other.shape().num_elements() * in_size) {
    return errors::InvalidArgument(
        "input and output shapes/data type sizes are not compatible");
  }
  shape_ = shape;
  shape_.set_data_type(dtype);
  if (buf_ != other.buf_) {
    UnrefIfNonNull(buf_);
    buf_ = other.buf_;
    RefIfNonNull(buf_);
  }
  return OkStatus();
}
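
// Illustrative sketch: bitcasting a DT_FLOAT tensor of shape {4} to DT_INT32
// of shape {4} succeeds because 4 * sizeof(float) == 4 * sizeof(int32), and
// the result shares the source's underlying buffer:
//
//   Tensor floats(DT_FLOAT, TensorShape({4}));
//   Tensor ints;
//   TF_CHECK_OK(ints.BitcastFrom(floats, DT_INT32, TensorShape({4})));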

// Notice that buf_ either points to a regular TensorBuffer or a SubBuffer.
// For the latter case, we have to make sure that the refcount is
// one both for the SubBuffer _and_ the underlying TensorBuffer.
bool Tensor::RefCountIsOne() const {
  return buf_ != nullptr && buf_->RefCountIsOne() &&
         buf_->root_buffer()->RefCountIsOne() && buf_->OwnsMemory();
}

// The macro CASES() expands to a switch statement conditioned on
// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS)               \
  case DataTypeToEnum<TYPE>::value: {   \
    typedef TF_ATTRIBUTE_UNUSED TYPE T; \
    STMTS;                              \
    break;                              \
  }
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
  switch (TYPE_ENUM) {                                         \
    CASE(float, SINGLE_ARG(STMTS))                             \
    CASE(double, SINGLE_ARG(STMTS))                            \
    CASE(int32, SINGLE_ARG(STMTS))                             \
    CASE(uint8, SINGLE_ARG(STMTS))                             \
    CASE(uint16, SINGLE_ARG(STMTS))                            \
    CASE(uint32, SINGLE_ARG(STMTS))                            \
    CASE(uint64, SINGLE_ARG(STMTS))                            \
    CASE(int16, SINGLE_ARG(STMTS))                             \
    CASE(int8, SINGLE_ARG(STMTS))                              \
    CASE(tstring, SINGLE_ARG(STMTS))                           \
    CASE(complex64, SINGLE_ARG(STMTS))                         \
    CASE(complex128, SINGLE_ARG(STMTS))                        \
    CASE(int64_t, SINGLE_ARG(STMTS))                           \
    CASE(bool, SINGLE_ARG(STMTS))                              \
    CASE(qint32, SINGLE_ARG(STMTS))                            \
    CASE(quint8, SINGLE_ARG(STMTS))                            \
    CASE(qint8, SINGLE_ARG(STMTS))                             \
    CASE(quint16, SINGLE_ARG(STMTS))                           \
    CASE(qint16, SINGLE_ARG(STMTS))                            \
    CASE(bfloat16, SINGLE_ARG(STMTS))                          \
    CASE(Eigen::half, SINGLE_ARG(STMTS))                       \
    CASE(ResourceHandle, SINGLE_ARG(STMTS))                    \
    CASE(Variant, SINGLE_ARG(STMTS))                           \
    case DT_INVALID:                                           \
      INVALID;                                                 \
      break;                                                   \
    default:                                                   \
      DEFAULT;                                                 \
      break;                                                   \
  }

#define CASES(TYPE_ENUM, STMTS)                                      \
  CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
                     , LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
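
// As an illustrative sketch, a statement such as
//
//   CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
//
// (used in the constructor below) expands into a switch over every supported
// DataType, with T bound to the matching C++ type in each case, and
// LOG(FATAL)s on DT_INVALID or an unexpected enum value.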

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
  }
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
                                      *this);
  }
}

Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
               const AllocationAttributes& allocation_attr)
    : shape_(shape), buf_(nullptr) {
  set_dtype(type);
  CHECK_NOTNULL(a);
  if (shape_.num_elements() > 0 || a->AllocatesOpaqueHandle()) {
    CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
  }
  if (MemoryLoggingEnabled() && !allocation_attr.allocation_will_be_logged &&
      buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (with attributes)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
}

Status Tensor::BuildTensor(DataType type, const TensorShape& shape,
                           Tensor* out_tensor) {
  // Avoid crashes due to invalid or unsupported types.
  CASES_WITH_DEFAULT(
      type, {}, return errors::InvalidArgument("Type not set"),
      return errors::InvalidArgument("Unexpected type: ", DataType_Name(type)));
  *out_tensor = Tensor(type, shape);
  return OkStatus();
}

// NOTE(mrry): The default allocator for a Tensor (when none is specified) is
// the default CPU allocator for NUMA zone 0. Accessing that currently involves
// acquiring a lock, which guards initialization of the per-NUMA zone
// allocators, and becomes highly contended.
//
// Note also that it would be better if all Tensor allocations required the user
// to specify an allocator, for purposes of accounting, etc. However, the
// default allocator is widely used throughout the codebase and in client code.
static Allocator* get_default_cpu_allocator() {
  static Allocator* default_cpu_allocator =
      cpu_allocator(port::kNUMANoAffinity);
  return default_cpu_allocator;
}

Tensor::Tensor(DataType type, const TensorShape& shape)
    : Tensor(get_default_cpu_allocator(), type, shape) {}

bool Tensor::HostScalarTensorBufferBase::GetAllocatedBytes(
    size_t* out_bytes) const {
  // `this->FillAllocationDescription()` never sets allocated bytes information,
  // so we can short-circuit the construction of an `AllocationDescription`.
  return false;
}

void Tensor::HostScalarTensorBufferBase::FillAllocationDescription(
    AllocationDescription* proto) const {
  proto->set_requested_bytes(size());
  proto->set_allocator_name("HostScalarTensorBuffer");
  proto->set_ptr(reinterpret_cast<uintptr_t>(data()));
}

template <typename T>
class SubBuffer : public TensorBuffer {
 public:
  // This buffer is an alias to buf[delta, delta + n).
  SubBuffer(TensorBuffer* buf, int64_t delta, int64_t n)
      : TensorBuffer(buf->base<T>() + delta),
        root_(buf->root_buffer()),
        elem_(n) {
    // Sanity check. The caller should ensure the sub buffer is valid.
    CHECK_LE(root_->base<T>(), this->base<T>());
    T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
    CHECK_LE(this->base<T>(), root_limit);
    CHECK_LE(this->base<T>() + n, root_limit);
    // Hold a ref of the underlying root buffer.
    // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
    root_->Ref();
  }

  size_t size() const override { return sizeof(T) * elem_; }
  TensorBuffer* root_buffer() override { return root_; }
  bool GetAllocatedBytes(size_t* out_bytes) const override {
    return root_->GetAllocatedBytes(out_bytes);
  }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    root_->FillAllocationDescription(proto);
  }

 private:
  TensorBuffer* root_;
  int64_t elem_;

  ~SubBuffer() override { root_->Unref(); }

  TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
};

Tensor Tensor::Slice(int64_t start, int64_t limit) const {
  CHECK_GE(dims(), 1);
  CHECK_LE(0, start);
  CHECK_LE(start, limit);
  int64_t dim0_size = shape_.dim_size(0);
  CHECK_LE(limit, dim0_size);
  if ((start == 0) && (limit == dim0_size)) {
    return *this;
  }
  Tensor ret;
  ret.shape_ = shape_;
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64_t elems_per_dim0 = NumElements() / dim0_size;
    const int64_t delta = start * elems_per_dim0;
    dim0_size = limit - start;
    ret.shape_.set_dim(0, dim0_size);
    const int64_t num_elems = dim0_size * elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
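
// Illustrative sketch: for a tensor t of dtype DT_FLOAT and shape {10, 4},
//
//   Tensor s = t.Slice(2, 6);
//
// returns a tensor of shape {4, 4} that aliases rows [2, 6) of t through a
// SubBuffer, so no element data is copied.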

Tensor Tensor::SubSlice(int64_t index) const {
  CHECK_GE(dims(), 1);  // Crash ok.
  CHECK_LE(0, index);   // Crash ok.
  int64_t dim0_size = shape_.dim_size(0);
  CHECK_LE(index, dim0_size);  // Crash ok.
  Tensor ret;
  ret.shape_ = shape_;
  ret.shape_.RemoveDim(0);
  ret.set_dtype(dtype());
  ret.buf_ = nullptr;
  if (dim0_size > 0) {
    const int64_t elems_per_dim0 = NumElements() / dim0_size;
    const int64_t delta = index * elems_per_dim0;
    const int64_t num_elems = elems_per_dim0;
    if (buf_) {
      DataType dt = dtype();
      CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
    }
  }
  return ret;
}
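
// Illustrative sketch: continuing the Slice() example above,
//
//   Tensor row = t.SubSlice(3);  // shape {4}, aliases row 3 of t.
//
// Unlike Slice(), SubSlice() removes the leading dimension from the result.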

bool Tensor::FromProto(const TensorProto& proto) {
  return FromProto(get_default_cpu_allocator(), proto);
}

bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
  CHECK_NOTNULL(a);
  TensorBuffer* p = nullptr;
  if (!TensorShape::IsValid(proto.tensor_shape())) return false;
  if (proto.dtype() == DT_INVALID) return false;
  TensorShape shape(proto.tensor_shape());
  const int64_t N = shape.num_elements();
  if (N > 0 && proto.dtype()) {
    bool dtype_error = false;
    if (!proto.tensor_content().empty()) {
      const auto& content = proto.tensor_content();
      CASES_WITH_DEFAULT(proto.dtype(), p = Helper<T>::Decode(a, content, N),
                         dtype_error = true, dtype_error = true);
    } else {
      CASES_WITH_DEFAULT(proto.dtype(), p = FromProtoField<T>(a, proto, N),
                         dtype_error = true, dtype_error = true);
    }
    if (dtype_error || p == nullptr) return false;
  } else {
    // Handle the case of empty tensors (N = 0) or tensors with incomplete shape
    // (N = -1). All other values of `shape.num_elements()` should be invalid by
    // construction.
    // Here, we just need to validate that the `proto.dtype()` value is valid.
    bool dtype_error = false;
    CASES_WITH_DEFAULT(proto.dtype(), break, dtype_error = true,
                       dtype_error = true);
    if (dtype_error) return false;
  }
  shape_ = shape;
  set_dtype(proto.dtype());
  UnrefIfNonNull(buf_);
  buf_ = p;
  // TODO(misard) add tracking of which kernels and steps are calling
  // FromProto.
  if (MemoryLoggingEnabled() && buf_ != nullptr && buf_->data() != nullptr) {
    LogMemory::RecordTensorAllocation("Unknown (from Proto)",
                                      LogMemory::UNKNOWN_STEP_ID, *this);
  }
  return true;
}
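
// Illustrative round-trip sketch:
//
//   Tensor a(DT_INT32, TensorShape({2, 2}));
//   TensorProto proto;
//   a.AsProtoTensorContent(&proto);  // compact tensor_content encoding
//   Tensor b;
//   CHECK(b.FromProto(proto));       // decodes via Helper<int32>::Decode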

void Tensor::AsProtoField(TensorProto* proto) const {
  proto->Clear();
  shape_.AsProto(proto->mutable_tensor_shape());
  proto->set_dtype(dtype());
  if (buf_) {
    CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
  }
}

void Tensor::AsProtoTensorContent(TensorProto* proto) const {
  proto->Clear();
  proto->set_dtype(dtype());
  shape_.AsProto(proto->mutable_tensor_shape());
  if (buf_) {
    CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
                                     proto->mutable_tensor_content()));
  }
}

size_t Tensor::TotalBytes() const {
  if (shape_.num_elements() == 0) return 0;
  CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
  CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
  return 0;  // Makes compiler happy.
}

size_t Tensor::AllocatedBytes() const {
  if (buf_) {
    size_t ret;
    if (buf_->GetAllocatedBytes(&ret)) {
      return ret;
    }
  }
  return TotalBytes();
}

bool Tensor::CanUseDMA() const {
  CASES(dtype(), return is_simple_type<T>::value);
  return false;  // Makes compiler happy.
}

#undef CASES
#undef CASE

namespace {

// StrCat and StrAppend don't support Eigen::half directly at the moment, and
// we would like to keep them compatible with their absl counterparts, for ease
// of migration. We could rely on errors::internal::PrepareForStrCat() but the
// logic is so simple we can just replicate it here, where it is close to its
// usage and easy to change later. And there's the extra benefit of not
// accessing an 'internal' namespace.
inline const strings::AlphaNum& PrintOneElement(const strings::AlphaNum& a,
                                                bool print_v2) {
  return a;
}
inline string PrintOneElement(const tstring& a, bool print_v2) {
  if (print_v2) {
    return "\"" + absl::Utf8SafeCEscape(a) + "\"";
  } else {
    return absl::Utf8SafeCEscape(a);
  }
}
inline float PrintOneElement(const Eigen::half& h, bool print_v2) {
  return static_cast<float>(h);
}

inline float PrintOneElement(bfloat16 f, bool print_v2) {
  return static_cast<float>(f);
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                 int64_t limit, int shape_size, const T* data,
                 int64_t* data_index, string* result) {
  if (*data_index >= limit) return;
  int64_t element_count = shape[dim_index];
  // We have reached the right-most dimension of the tensor.
  if (dim_index == shape_size - 1) {
    for (int64_t i = 0; i < element_count; i++) {
      if (*data_index >= limit) {
        // If not enough elements have been printed, append "...".
        if (dim_index != 0) {
          strings::StrAppend(result, "...");
        }
        return;
      }
      if (i > 0) strings::StrAppend(result, " ");
      strings::StrAppend(result, PrintOneElement(data[(*data_index)++], false));
    }
    return;
  }
  // Loop over every element of one dim.
  for (int64_t i = 0; i < element_count; i++) {
    bool flag = false;
    if (*data_index < limit) {
      strings::StrAppend(result, "[");
      flag = true;
    }
    // For each element, print the sub-dim.
    PrintOneDim(dim_index + 1, shape, limit, shape_size, data, data_index,
                result);
    if (*data_index < limit || flag) {
      strings::StrAppend(result, "]");
      flag = false;
    }
  }
}

// Appends the spacing between elements for a given dim onto a result string.
void PrintDimSpacing(int dim_index, int num_dims, string* result) {
  if (dim_index == num_dims - 1) {
    strings::StrAppend(result, " ");
    return;
  }
  for (int j = 0; j < num_dims - dim_index - 1; j++) {
    strings::StrAppend(result, "\n");
  }
  for (int j = 0; j <= dim_index; j++) {
    strings::StrAppend(result, " ");
  }
}

// Print from left dim to right dim recursively.
template <typename T>
void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
                   int64_t num_elts_at_ends, int num_dims, const T* data,
                   int64_t data_index, string* result) {
  // We have recursed beyond all the dimensions into a single element
  // of the tensor.
  if (dim_index == num_dims) {
    strings::StrAppend(result, PrintOneElement(data[data_index], true));
    return;
  }

  strings::StrAppend(result, "[");
  int64_t element_count = shape[dim_index];
  int64_t start_of_end =
      std::max(num_elts_at_ends, element_count - num_elts_at_ends);

  // Loop over every element of one dim.
  int64_t elements_per_iter = 1;
  for (int i = dim_index + 1; i < num_dims; i++) {
    elements_per_iter *= shape[i];
  }
  for (int64_t i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
    if (i > 0) {
      PrintDimSpacing(dim_index, num_dims, result);
    }

    // For each element, print the sub-dim.
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }
  if (element_count > 2 * num_elts_at_ends) {
    PrintDimSpacing(dim_index, num_dims, result);
    strings::StrAppend(result, "...");
  }
  for (int64_t i = start_of_end; i < element_count; i++) {
    // For each element, print the sub-dim.
    PrintDimSpacing(dim_index, num_dims, result);
    PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
                  data_index + elements_per_iter * i, result);
  }

  strings::StrAppend(result, "]");
}

template <typename T>
string SummarizeArray(int64_t limit, int64_t num_elts,
                      const TensorShape& tensor_shape, const char* data,
                      const bool print_v2) {
  string ret;
  const T* array = reinterpret_cast<const T*>(data);

  const gtl::InlinedVector<int64_t, 4> shape = tensor_shape.dim_sizes();
  if (shape.empty()) {
    for (int64_t i = 0; i < limit; ++i) {
      if (i > 0) strings::StrAppend(&ret, " ");
      strings::StrAppend(&ret, PrintOneElement(array[i], print_v2));
    }
    if (num_elts > limit) strings::StrAppend(&ret, "...");
    return ret;
  }
  if (print_v2) {
    const int num_dims = tensor_shape.dims();
    PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
  } else {
    int64_t data_index = 0;
    const int shape_size = tensor_shape.dims();
    PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);

    if (num_elts > limit) strings::StrAppend(&ret, "...");
  }

  return ret;
}
}  // namespace

string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const {
  const int64_t num_elts = NumElements();
  if (max_entries < 0) {
    max_entries = num_elts;
  }
  size_t limit = std::min(max_entries, num_elts);
  if ((limit > 0) && (buf_ == nullptr)) {
    return strings::StrCat("uninitialized Tensor of ", num_elts,
                           " elements of type ", dtype());
  }
  const char* data = limit > 0 ? tensor_data().data() : nullptr;
  switch (dtype()) {
    case DT_BFLOAT16:
      return SummarizeArray<bfloat16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_HALF:
      return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
                                         print_v2);
      break;
    case DT_FLOAT:
      return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_DOUBLE:
      return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT32:
      return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT32:
      return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT8:
    case DT_QUINT8:
      return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT16:
    case DT_QUINT16:
      return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT16:
    case DT_QINT16:
      return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT8:
    case DT_QINT8:
      return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_UINT64:
      return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_INT64:
      return SummarizeArray<int64_t>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_BOOL:
      // TODO(tucker): Is it better to emit "True False..."?  This
      // will emit "1 0..." which is more compact.
      return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
      break;
    case DT_STRING:
      return SummarizeArray<tstring>(limit, num_elts, shape_, data, print_v2);
      break;
    default: {
      // All irregular cases
      string ret;
      if (print_v2 && (dims() > 0)) {
        strings::StrAppend(&ret, "[");
      }
      // TODO(irving): Don't call flat every time around this
      // loop.
      for (size_t i = 0; i < limit; ++i) {
        if (i > 0) strings::StrAppend(&ret, " ");
        switch (dtype()) {
          case DT_VARIANT: {
            const Variant& v = flat<Variant>()(i);
            strings::StrAppend(&ret, "<", v.SummarizeValue(), ">");
          } break;
          case DT_RESOURCE: {
            const ResourceHandle& r = flat<ResourceHandle>()(i);
            strings::StrAppend(&ret, "<", r.SummarizeValue(), ">");
          } break;
          default:
            // TODO(zhifengc, josh11b): Pretty-print other types (bool,
            // complex64, quantized).
            strings::StrAppend(&ret, "?");
        }
      }
      if (max_entries < num_elts) strings::StrAppend(&ret, "...");
      if (print_v2 && (dims() > 0)) {
        strings::StrAppend(&ret, "]");
      }
      return ret;
    }
  }
}

StringPiece Tensor::tensor_data() const {
  if (buf_ == nullptr) return StringPiece();  // Don't die for empty tensors
  return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}

void* Tensor::data() const {
  if (buf_ == nullptr) return nullptr;  // Don't die for empty tensors
  return static_cast<void*>(buf_->data());
}

bool Tensor::SharesBufferWith(const Tensor& b) const {
  return buf_ != nullptr && b.buf_ != nullptr &&
         buf_->root_buffer() == b.buf_->root_buffer();
}

string Tensor::DebugString(int num_values) const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(),
                         " values: ", SummarizeValue(num_values), ">");
}

string Tensor::DeviceSafeDebugString() const {
  return strings::StrCat("Tensor<type: ", DataTypeString(dtype()),
                         " shape: ", shape().DebugString(), ">");
}

void Tensor::FillDescription(TensorDescription* description) const {
  description->set_dtype(dtype());
  shape().AsProto(description->mutable_shape());
  if (buf_ != nullptr && buf_->data() != nullptr) {
    buf_->FillAllocationDescription(
        description->mutable_allocation_description());
  }
}

gtl::InlinedVector<int64_t, 4> Tensor::ComputeFlatInnerDims(
    gtl::ArraySlice<int64_t> orig, int64_t num_out_dims) {
  gtl::InlinedVector<int64_t, 4> out_dims(num_out_dims, 0);
  int64_t offset = orig.size() - num_out_dims;
  for (int64_t out_dim = num_out_dims - 1; out_dim >= 0; --out_dim) {
    const int64_t in_dim = out_dim + offset;
    out_dims[out_dim] = in_dim < 0 ? 1 : orig[in_dim];
  }
  for (int64_t in_dim = 0; in_dim < offset; ++in_dim) {
    out_dims[0] *= orig[in_dim];
  }
  return out_dims;
}
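
// Illustrative sketch: ComputeFlatInnerDims({2, 3, 4, 5}, 2) returns {24, 5}.
// The trailing num_out_dims - 1 dimensions are kept, and everything to their
// left is collapsed into the first output dimension (2 * 3 * 4 = 24).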

gtl::InlinedVector<int64_t, 4> Tensor::ComputeFlatOuterDims(
    gtl::ArraySlice<int64_t> orig, int64_t num_out_dims) {
  gtl::InlinedVector<int64_t, 4> out_dims(num_out_dims, 0);
  for (int64_t out_dim = 0; out_dim <= num_out_dims - 1; ++out_dim) {
    out_dims[out_dim] = out_dim >= orig.size() ? 1 : orig[out_dim];
  }
  for (int64_t in_dim = num_out_dims; in_dim < orig.size(); ++in_dim) {
    out_dims[num_out_dims - 1] *= orig[in_dim];
  }
  return out_dims;
}
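
// Illustrative sketch: ComputeFlatOuterDims({2, 3, 4, 5}, 2) returns {2, 60}.
// The leading num_out_dims - 1 dimensions are kept, and everything to their
// right is collapsed into the last output dimension (3 * 4 * 5 = 60).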

}  // namespace tensorflow