xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/variant.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
18 
19 #include <functional>
20 #include <iostream>
21 #include <memory>
22 #include <type_traits>
23 #include <unordered_map>
24 #include <utility>
25 
26 #include "absl/memory/memory.h"
27 #include "tensorflow/core/framework/type_index.h"
28 #include "tensorflow/core/framework/variant_tensor_data.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/strcat.h"
31 
32 namespace tensorflow {
33 
34 template <typename T>
35 std::string TypeNameVariant(const T& value);
36 
37 template <typename T>
38 std::string DebugStringVariant(const T& value);
39 
40 // Allows for specializations of Variant Decoding.  `data` may be modified in
41 // the process of decoding to `value`.
42 template <typename T>
43 bool DecodeVariant(VariantTensorData* data, T* value);
44 
45 template <typename T>
46 bool DecodeVariant(std::string* buf, T* value);
47 
48 template <typename T>
49 void EncodeVariant(const T& value, VariantTensorData* data);
50 
51 template <typename T>
52 void EncodeVariant(const T& value, std::string* buf);
53 
54 // This is an implementation of a type-erased container that can store an
55 // object of any type. The implementation is very similar to std::any, but has
56 // restrictions on the types of objects that can be stored, and eschews some of
57 // the fancier constructors available for std::any. An object of
58 // tensorflow::Variant is intended to be used as the value that will be stored
59 // in a tensorflow::Tensor object when its type is DT_VARIANT.
60 //
61 // tensorflow::Variant can store an object of a class that satisfies the
62 // following constraints:
63 //
64 // * The class is CopyConstructible.
65 // * The class has a default constructor.
66 // * It's either a protocol buffer, a tensorflow::Tensor, or defines the
67 // following functions:
68 //
69 //   string TypeName() const;
70 //   void Encode(VariantTensorData* data) const;
71 //   bool Decode(VariantTensorData data);
72 //
// Simple POD types can elide the Encode/Decode functions; they are provided
// by helper methods.
75 // Here are some typical usage patterns:
76 //
77 //   Variant x = 10;
78 //   EXPECT_EQ(*x.get<int>(), 10);
79 //
80 //   Tensor t(DT_FLOAT, TensorShape({}));
81 //   t.flat<float>()(0) = 42.0f;
82 //   Variant x = t;
83 //   EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
84 //
85 // Accessing the stored object:
86 //
// The get<T> function is the main mechanism to access the object
// stored in the container. It is type-safe: calling get<T> when the
// stored object's type is not T returns nullptr. A raw pointer to the
// stored object can be obtained by calling get<void>().
92 //
93 // Serializing/deserializing Variant object:
94 //
95 // The Variant class delegates serializing and deserializing operations to the
96 // contained object. Helper functions to do these operations are provided for
97 // POD data types, tensorflow::Tensor, and protocol buffer objects. However,
98 // other classes have to provide Encode/Decode functions to handle
99 // serialization.
100 //
// Objects stored in a Variant object often contain references to other
// tensorflow::Tensors of primitive types (e.g., a list of tensorflow::Tensors).
103 // To efficiently support those use cases, a structure is imposed on the
104 // serialization format. Namely, classes should serialize their contents into a
105 // VariantTensorData object:
106 //
107 //   struct VariantTensorData {
108 //     string type_name;
109 //     string metadata;
110 //     std::vector<Tensor> tensors;
111 //   };
112 //
113 // Objects with references to other Tensors can simply store those tensors in
// the `tensors` field, and serialize other metadata content into the
// `metadata` field.
116 //
117 // Serialization example:
118 //
119 //   Foo f = Foo {...};
120 //   Variant x = f;
121 //   string serialized_f;
122 //   x.Encode(&serialized_f);
123 //
124 //   Variant y = Foo(); // default constructed Foo.
125 //   y.Decode(std::move(serialized_f));
126 //   EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
127 //
128 //
129 // A Variant storing serialized Variant data (a value of type
130 // VariantTensorDataProto) has different behavior from a standard Variant.
131 // Namely, its TypeName matches the TypeName of the original Variant;
132 // and its non-const get method performs lazy deserialization.
133 //
134 // Decode and copy example:
135 //
136 //   Foo f = Foo {...};
137 //   Variant x = f;
138 //
139 //   VariantTensorData serialized_data_f;
140 //   VariantTensorDataProto serialized_proto_f;
141 //   x.Encode(&serialized_data_f);
142 //   serialized_data_f.ToProto(&serialized_proto_f);
143 //
144 //   Variant y_type_unknown = serialized_proto_f;  // Store serialized Variant.
145 //
146 //   EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName());  // Looks like Foo.
147 //   EXPECT_EQ(TypeIndex::Make<VariantTensorDataProto>(),
148 //             y_type_unknown.TypeId());
149 //
150 class Variant {
151  public:
152   // Constructs a Variant holding no value (aka `is_empty()`).
153   //
154   // This is done by pointing at nullptr via the heap value.
Variant()155   Variant() noexcept : heap_value_(/*pointer=*/nullptr), is_inline_(false) {}
156 
157   ~Variant();
158 
159   Variant(const Variant& other);
160   Variant(Variant&& other) noexcept;
161 
162   // Make sure that the type is CopyConstructible and not a
163   // tensorflow::Variant object itself. We want the copy constructor to be
164   // chosen for the tensorflow::Variant case.
165   template <typename T, typename VT = typename std::decay<T>::type,
166             typename std::enable_if<!std::is_same<Variant, VT>::value &&
167                                         std::is_move_constructible<VT>::value,
168                                     void>::type* = nullptr>
169   Variant(T&& value);
170 
171   template <typename T, typename VT = typename std::decay<T>::type,
172             typename std::enable_if<!std::is_same<Variant, VT>::value &&
173                                         std::is_copy_constructible<VT>::value,
174                                     void>::type* = nullptr>
175   Variant(const T& value);
176 
177   template <typename T, typename VT = typename std::decay<T>::type,
178             typename std::enable_if<!std::is_same<Variant, VT>::value &&
179                                         std::is_copy_constructible<VT>::value,
180                                     void>::type* = nullptr>
181   Variant& operator=(const T& value);
182 
183   template <typename T, typename VT = typename std::decay<T>::type,
184             typename std::enable_if<!std::is_same<Variant, VT>::value &&
185                                         std::is_move_constructible<VT>::value,
186                                     void>::type* = nullptr>
187   Variant& operator=(T&& value);
188 
189   Variant& operator=(const Variant& rhs) {
190     if (&rhs == this) return *this;
191     Variant(rhs).swap(*this);
192     return *this;
193   }
194 
195   Variant& operator=(Variant&& rhs) noexcept {
196     if (&rhs == this) return *this;
197     Variant(std::move(rhs)).swap(*this);
198     return *this;
199   }
200 
201   // Constructs a value of type T with the given args in-place in this Variant.
202   // Returns a reference to the newly constructed value.
203   // The signature is based on std::variant<Types...>::emplace() in C++17.
204   template <typename T, class... Args>
emplace(Args &&...args)205   T& emplace(Args&&... args) {
206     ResetMemory();
207     is_inline_ = CanInlineType<T>();
208     if (is_inline_) {
209       new (&inline_value_)
210           InlineValue(InlineValue::Tag<T>{}, std::forward<Args>(args)...);
211       return static_cast<Variant::Value<T>*>(inline_value_.AsValueInterface())
212           ->value;
213     } else {
214       new (&heap_value_) HeapValue(
215           absl::make_unique<Value<T>>(InPlace(), std::forward<Args>(args)...));
216       return static_cast<Variant::Value<T>*>(heap_value_.get())->value;
217     }
218   }
219 
is_empty()220   bool is_empty() const { return GetValue() == nullptr; }
221 
222   void clear() noexcept;
223 
224   void swap(Variant& other) noexcept;
225 
226   // Note, unlike TypeName(), TypeId() does not return the TypeIndex
227   // of the original type when a TensorValueDataProto is stored as the
228   // value.  In this case, it returns the TypeIndex of TensorValueDataProto.
TypeId()229   TypeIndex TypeId() const {
230     const TypeIndex VoidTypeIndex = TypeIndex::Make<void>();
231     if (is_empty()) {
232       return VoidTypeIndex;
233     }
234     return GetValue()->TypeId();
235   }
236 
DebugString()237   std::string DebugString() const {
238     return strings::StrCat("Variant<type: ", TypeName(),
239                            " value: ", SummarizeValue(), ">");
240   }
241 
SummarizeValue()242   std::string SummarizeValue() const {
243     return is_empty() ? "[empty]" : GetValue()->DebugString();
244   }
245 
246   // Returns a pointer to the stored value if it is type T, or nullptr
247   // otherwise.
248   template <typename T>
get()249   T* get() {
250     const TypeIndex TTypeIndex = TypeIndex::Make<T>();
251     if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
252     return std::addressof(static_cast<Variant::Value<T>*>(GetValue())->value);
253   }
254 
255   // Returns a pointer to the stored value if it is type T, or nullptr
256   // otherwise.
257   template <typename T>
get()258   const T* get() const {
259     const TypeIndex TTypeIndex = TypeIndex::Make<T>();
260     if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
261     return std::addressof(
262         static_cast<const Variant::Value<T>*>(GetValue())->value);
263   }
264 
265   // Returns TypeNameVariant(value).
266   //
267   // In the special case that a serialized Variant is stored (value
268   // is a VariantTensorDataProto), returns value.TypeName(), the
269   // TypeName field stored in the VariantTensorDataProto buffer.
TypeName()270   std::string TypeName() const {
271     if (is_empty()) {
272       return "";
273     }
274     return GetValue()->TypeName();
275   }
276 
277   // Serialize the contents of the stored object into `data`.
Encode(VariantTensorData * data)278   void Encode(VariantTensorData* data) const {
279     if (!is_empty()) {
280       GetValue()->Encode(data);
281     }
282   }
283 
284   // Deserialize `data` and update the stored object.
285   bool Decode(VariantTensorData data);
286 
287   // Helper methods to directly serialize/deserialize from strings.
Encode(std::string * buf)288   void Encode(std::string* buf) const {
289     if (!is_empty()) {
290       GetValue()->Encode(buf);
291     }
292   }
Decode(std::string buf)293   bool Decode(std::string buf) {
294     if (!is_empty()) {
295       return GetValue()->Decode(std::move(buf));
296     }
297     return true;
298   }
299 
300   template <typename VT>
CanInlineType()301   static constexpr bool CanInlineType() {
302     return ((sizeof(Value<VT>) <= InlineValue::kMaxValueSize) &&
303             (alignof(Value<VT>) <= kMaxInlineValueAlignSize));
304   }
305 
306  private:
307   struct in_place_t {};
InPlace()308   static constexpr in_place_t InPlace() { return in_place_t{}; }
309 
310   struct ValueInterface {
311     virtual ~ValueInterface() = default;
312     virtual TypeIndex TypeId() const = 0;
313     virtual void* RawPtr() = 0;
314     virtual const void* RawPtr() const = 0;
315     virtual std::unique_ptr<ValueInterface> Clone() const = 0;
316     virtual void CloneInto(ValueInterface* memory) const = 0;
317     virtual void MoveAssign(ValueInterface* memory) = 0;
318     virtual void MoveInto(ValueInterface* memory) = 0;
319     virtual std::string TypeName() const = 0;
320     virtual std::string DebugString() const = 0;
321     virtual void Encode(VariantTensorData* data) const = 0;
322     virtual bool Decode(VariantTensorData data) = 0;
323     virtual void Encode(std::string* buf) const = 0;
324     virtual bool Decode(std::string data) = 0;
325   };
326 
327   template <typename T>
328   struct Value final : ValueInterface {
329     template <class... Args>
Valuefinal330     explicit Value(in_place_t /*tag*/, Args&&... args)
331         : value(std::forward<Args>(args)...) {}
332 
333     // NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily
334     // build `alignof(Variant<void*>)`.
335     ~Value() final = default;
336 
TypeIdfinal337     TypeIndex TypeId() const final {
338       const TypeIndex value_type_index =
339           TypeIndex::Make<typename std::decay<T>::type>();
340       return value_type_index;
341     }
342 
RawPtrfinal343     void* RawPtr() final { return &value; }
344 
RawPtrfinal345     const void* RawPtr() const final { return &value; }
346 
Clonefinal347     std::unique_ptr<ValueInterface> Clone() const final {
348       return absl::make_unique<Value>(InPlace(), value);
349     }
350 
MoveAssignfinal351     void MoveAssign(ValueInterface* memory) final {
352       CHECK(TypeId() == memory->TypeId())
353           << TypeId().name() << " vs. " << memory->TypeId().name();
354       static_cast<Value*>(memory)->value = std::move(value);
355     }
356 
CloneIntofinal357     void CloneInto(ValueInterface* memory) const final {
358       new (memory) Value(InPlace(), value);
359     }
360 
MoveIntofinal361     void MoveInto(ValueInterface* memory) final {
362       new (memory) Value(InPlace(), std::move(value));
363     }
364 
TypeNamefinal365     std::string TypeName() const final { return TypeNameVariant(value); }
366 
DebugStringfinal367     std::string DebugString() const final { return DebugStringVariant(value); }
368 
Encodefinal369     void Encode(VariantTensorData* data) const final {
370       EncodeVariant(value, data);
371     }
372 
Decodefinal373     bool Decode(VariantTensorData data) final {
374       return DecodeVariant(&data, &value);
375     }
376 
Encodefinal377     void Encode(std::string* buf) const final { EncodeVariant(value, buf); }
378 
Decodefinal379     bool Decode(std::string buf) final { return DecodeVariant(&buf, &value); }
380 
381     T value;
382   };
383   static constexpr int kMaxInlineValueAlignSize = alignof(Value<void*>);
384 
385   using HeapValue = std::unique_ptr<ValueInterface>;
386 
387   struct InlineValue {
388     // We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit
389     // into the aligned space of a TensorBuffer.
390     static constexpr int kMaxValueSize = (64 - /*some extra padding=*/8);
391 
392     typedef char ValueDataArray[kMaxValueSize];
393     alignas(kMaxInlineValueAlignSize) ValueDataArray value_data;
394 
395     // Tag is used for deducing the right type when constructing a Value in
396     // place.
397     template <typename VT>
398     struct Tag {};
399 
400     template <typename VT, class... Args>
InlineValueInlineValue401     explicit InlineValue(Tag<VT> /*tag*/, Args&&... args) noexcept {
402       Value<VT>* inline_value_data = reinterpret_cast<Value<VT>*>(value_data);
403       new (inline_value_data) Value<VT>(InPlace(), std::forward<Args>(args)...);
404     }
405 
InlineValueInlineValue406     InlineValue(const InlineValue& other) noexcept {
407       other.AsValueInterface()->CloneInto(AsValueInterface());
408     }
409 
InlineValueInlineValue410     InlineValue(InlineValue&& other) noexcept {
411       other.AsValueInterface()->MoveInto(AsValueInterface());
412     }
413 
ResetMemoryInlineValue414     void ResetMemory() { AsValueInterface()->~ValueInterface(); }
415 
416     InlineValue& operator=(const InlineValue& other) {
417       if (&other == this) return *this;
418       ResetMemory();
419       other.AsValueInterface()->CloneInto(AsValueInterface());
420       return *this;
421     }
422 
423     InlineValue& operator=(InlineValue&& other) {
424       if (&other == this) return *this;
425       if (AsValueInterface()->TypeId() == other.AsValueInterface()->TypeId()) {
426         other.AsValueInterface()->MoveAssign(AsValueInterface());
427       } else {
428         ResetMemory();
429         other.AsValueInterface()->MoveInto(AsValueInterface());
430       }
431       return *this;
432     }
433 
AsValueInterfaceInlineValue434     ValueInterface* AsValueInterface() {
435       return reinterpret_cast<ValueInterface*>(value_data);
436     }
437 
AsValueInterfaceInlineValue438     const ValueInterface* AsValueInterface() const {
439       return reinterpret_cast<const ValueInterface*>(value_data);
440     }
441 
~InlineValueInlineValue442     ~InlineValue() { ResetMemory(); }
443   };
444 
445   union {
446     HeapValue heap_value_;
447     InlineValue inline_value_;
448   };
449   // is_inline_ provides discrimination between which member of the prior union
450   // is currently within it's lifetime. To switch from one member to the other,
451   // the destructor must be called on the currently alive member before calling
452   // the constructor on the other member. In effect, a member is expected to be
453   // live at any given time and that member is tracked via this boolean.
454   bool is_inline_;
455 
IsInlineValue()456   bool IsInlineValue() const { return is_inline_; }
457 
458   // ResetMemory causes the destructor of the currently active member of the
459   // union to be run. This must be follwed with a placement new call on the
460   // member whose lifetime is to start. Additionally, is_inline_ needs to be set
461   // accordingly. ResetAndSetInline and ResetAndSetHeap are simple helper
462   // functions for performing the actions that are required to follow.
ResetMemory()463   void ResetMemory() {
464     if (IsInlineValue()) {
465       inline_value_.~InlineValue();
466     } else {
467       heap_value_.~HeapValue();
468     }
469   }
470 
471   // ResetAndSetInline clears the current state and then constructs a new value
472   // inline with the provided arguments.
473   template <typename... Args>
ResetAndSetInline(Args &&...args)474   void ResetAndSetInline(Args&&... args) noexcept {
475     ResetMemory();
476     new (&inline_value_) InlineValue(std::forward<Args>(args)...);
477     is_inline_ = true;
478   }
479 
480   // ResetAndSetHeap clears the current state then constructs a new value on the
481   // heap with the provided arguments.
482   template <typename... Args>
ResetAndSetHeap(Args &&...args)483   void ResetAndSetHeap(Args&&... args) noexcept {
484     ResetMemory();
485     new (&heap_value_) HeapValue(std::forward<Args>(args)...);
486     is_inline_ = false;
487   }
488 
GetValue()489   ValueInterface* GetValue() {
490     if (IsInlineValue()) {
491       return inline_value_.AsValueInterface();
492     } else {
493       return heap_value_.get();
494     }
495   }
496 
GetValue()497   const ValueInterface* GetValue() const {
498     if (IsInlineValue()) {
499       return inline_value_.AsValueInterface();
500     } else {
501       return heap_value_.get();
502     }
503   }
504 
505   // PRECONDITION: Called on construction or ResetMemory() has been called
506   // before this method.
507   template <typename VT, typename T>
InsertValue(T && value)508   void InsertValue(T&& value) {
509     if (IsInlineValue()) {
510       new (&inline_value_)
511           InlineValue(InlineValue::Tag<VT>{}, std::forward<T>(value));
512     } else {
513       new (&heap_value_) HeapValue(
514           absl::make_unique<Value<VT>>(InPlace(), std::forward<T>(value)));
515     }
516   }
517 };
518 
519 // Make sure that a Variant object can reside in a 64-byte aligned Tensor
520 // buffer.
521 static_assert(sizeof(Variant) <= 64,
522               "Expected internal representation to be 64 bytes.");
523 
// Copy constructor: mirrors `other`'s storage kind and deep-copies its value.
inline Variant::Variant(const Variant& other)
    : is_inline_(other.IsInlineValue()) {
  if (is_inline_) {
    // Inline storage: InlineValue's copy constructor clones into our buffer.
    new (&inline_value_) InlineValue(other.inline_value_);
    return;
  }
  // Heap storage: clone the pointee; a null (empty) source stays null.
  HeapValue cloned;
  if (other.heap_value_ != nullptr) cloned = other.heap_value_->Clone();
  new (&heap_value_) HeapValue(std::move(cloned));
}
532 }
533 
Variant(Variant && other)534 inline Variant::Variant(Variant&& other) noexcept
535     : is_inline_(other.IsInlineValue()) {
536   if (IsInlineValue()) {
537     new (&inline_value_) InlineValue(std::move(other.inline_value_));
538   } else {
539     new (&heap_value_) HeapValue(std::move(other.heap_value_));
540   }
541 }
542 
543 template <typename T, typename VT,
544           typename std::enable_if<!std::is_same<Variant, VT>::value &&
545                                       std::is_move_constructible<VT>::value,
546                                   void>::type*>
Variant(T && value)547 inline Variant::Variant(T&& value) : is_inline_(CanInlineType<VT>()) {
548   InsertValue<VT>(std::forward<T>(value));
549 }
550 
551 template <typename T, typename VT,
552           typename std::enable_if<!std::is_same<Variant, VT>::value &&
553                                       std::is_copy_constructible<VT>::value,
554                                   void>::type*>
Variant(const T & value)555 inline Variant::Variant(const T& value) : is_inline_(CanInlineType<VT>()) {
556   InsertValue<VT>(value);
557 }
558 
559 template <typename T, typename VT,
560           typename std::enable_if<!std::is_same<Variant, VT>::value &&
561                                       std::is_move_constructible<VT>::value,
562                                   void>::type*>
563 inline Variant& Variant::operator=(T&& value) {
564   ResetMemory();
565   is_inline_ = CanInlineType<VT>();
566   InsertValue<VT>(std::forward<T>(value));
567   return *this;
568 }
569 
570 template <typename T, typename VT,
571           typename std::enable_if<!std::is_same<Variant, VT>::value &&
572                                       std::is_copy_constructible<VT>::value,
573                                   void>::type*>
574 inline Variant& Variant::operator=(const T& value) {
575   ResetMemory();
576   is_inline_ = CanInlineType<VT>();
577   InsertValue<VT>(value);
578   return *this;
579 }
580 
clear()581 inline void Variant::clear() noexcept {
582   // We set the internal unique_ptr to nullptr so that we preserve the
583   // invariant that one of the two states must be set at all times. nullptr
584   // indicates that the variant is empty.
585   ResetAndSetHeap(/*pointer=*/nullptr);
586 }
587 
swap(Variant & other)588 inline void Variant::swap(Variant& other) noexcept {
589   if (is_empty()) {
590     if (other.IsInlineValue()) {
591       ResetAndSetInline(std::move(other.inline_value_));
592     } else {
593       ResetAndSetHeap(std::move(other.heap_value_));
594     }
595     other.clear();
596   } else if (other.is_empty()) {
597     if (IsInlineValue()) {
598       other.ResetAndSetInline(std::move(inline_value_));
599     } else {
600       other.ResetAndSetHeap(std::move(heap_value_));
601     }
602     clear();
603   } else {  // Both Variants have values.
604     if (other.IsInlineValue() && IsInlineValue()) {
605       std::swap(inline_value_, other.inline_value_);
606     } else if (!other.IsInlineValue() && !IsInlineValue()) {
607       std::swap(heap_value_, other.heap_value_);
608     } else if (other.IsInlineValue() && !IsInlineValue()) {
609       HeapValue v = std::move(heap_value_);
610       ResetAndSetInline(std::move(other.inline_value_));
611       other.ResetAndSetHeap(std::move(v));
612     } else {  // !other.IsInlineValue() && IsInlineValue()
613       HeapValue v = std::move(other.heap_value_);
614       other.ResetAndSetInline(std::move(inline_value_));
615       ResetAndSetHeap(std::move(v));
616     }
617   }
618 }
619 
620 template <>
621 void* Variant::get();
622 
623 template <>
624 const void* Variant::get() const;
625 
626 }  // end namespace tensorflow
627 
628 #endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
629