xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/experimental/libtf/value.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 /// @file value.h
16 /// @brief The TaggedValue struct that supports Python-like behavior in C++.
17 ///
18 /// The TaggedValue struct implements a tagged union data structure
19 /// (https://en.wikipedia.org/wiki/Tagged_union) in the TensorFlow C++ API. It
20 /// contains a `Type` enum (sometimes referred to as a "tag")
21 /// and a `Data` union for holding values.
22 
23 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_
24 #define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_
25 
26 #include <iostream>
27 #include <memory>
28 #include <utility>
29 #include <vector>
30 
31 #include "absl/container/flat_hash_map.h"
32 #include "tensorflow/c/eager/abstract_tensor_handle.h"
33 #include "tensorflow/core/platform/intrusive_ptr.h"
34 #include "tensorflow/core/platform/statusor.h"
35 
36 // TODO(b/195578409): Move all value objects into `impl`. Currently only values
37 // that do not reference TaggedValue are there.
38 #include "tensorflow/cc/experimental/libtf/impl/none.h"
39 #include "tensorflow/cc/experimental/libtf/impl/scalars.h"
40 #include "tensorflow/cc/experimental/libtf/impl/string.h"
41 #include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h"
42 
43 namespace tf {
44 namespace libtf {
45 namespace impl {
46 // Necessary forward declares.
47 class TaggedValue;
48 class Tuple;
49 template <class T>
50 // TODO(ccrusius): Use absl::Hash specializations instead.
51 class TaggedValueHash;
52 using List = std::vector<TaggedValue>;
53 using ListPtr = std::shared_ptr<List>;
54 using Dict =
55     absl::flat_hash_map<TaggedValue, TaggedValue, TaggedValueHash<TaggedValue>>;
56 using DictPtr = std::shared_ptr<Dict>;
57 using TuplePtr = std::shared_ptr<Tuple>;
58 using Func =
59     std::function<tensorflow::StatusOr<TaggedValue>(TaggedValue, TaggedValue)>;
60 // A capsule holds a pointer and a destructor for the pointer (i.e. a generic
61 // shared_ptr to void with a custom deleter).
62 using Capsule = std::shared_ptr<void>;
63 using TaggedValueTensor =
64     tensorflow::core::IntrusivePtr<tensorflow::AbstractTensorHandle>;
65 
66 // Declare hash types so they can be instantiated below.
67 
68 /// @brief TaggedValue hashing infrastructure, which uses absl::hash.
69 ///
70 /// Hashable TaggedValues overload `AbslHashValue`. Non-hashable structures
71 /// return 0.
72 template <>
73 struct TaggedValueHash<TaggedValue> {
74   size_t operator()(const TaggedValue& v) const;
75 };
76 
77 /// @brief Hash implementation for TaggedValue Tuples.
78 template <>
79 struct TaggedValueHash<Tuple> {
80   size_t operator()(const Tuple& t) const;
81 };
82 
83 /// @brief The basic `TaggedValue` tagged union type.
84 ///
85 /// A `TaggedValue` contains a `Type` (or "tag") as an enum and a `Value` union.
86 /// Values include tensors, primitive values, lists, tuples, and dictionaries.
87 /// In the future we might also want to have representation of python objects in
88 /// the form of PyObject*.
89 class TaggedValue final {
90  public:
91   /// @brief Enum that describes the possible types a `TaggedValue` can be.
92   ///
93   /// A `TaggedValue` must be one of the following types: NONE, INT64, FLOAT32,
94   /// STRING, FUNC, DICT, LIST, TUPLE, TENSOR, TENSOR_SPEC, CAPSULE.
95   enum Type {
96     NONE = 0,
97     INT64 = 1,
98     FLOAT32 = 2,
99     STRING = 3,
100     FUNC = 4,
101     DICT = 5,
102     LIST = 6,
103     TUPLE = 7,
104     TENSOR = 8,
105     TENSOR_SPEC = 9,
106     CAPSULE = 10,
107   };
108   TaggedValue() : type_(NONE), data_() {}
109 
110   /// Move assignment operator.
111   TaggedValue& operator=(TaggedValue&& v) {
112     destroy();
113     MoveIntoUnion(std::move(v));
114     return *this;
115   }
116   /// Move constructor.
117   TaggedValue(TaggedValue&& v) : type_(NONE) { MoveIntoUnion(std::move(v)); }
118   /// Copy constructor.
119   TaggedValue(const TaggedValue& v) : type_(NONE) { CopyIntoUnion(v); }
120   /// Copy assignment operator.
121   TaggedValue& operator=(const TaggedValue& v) {
122     destroy();
123     CopyIntoUnion(v);
124     return *this;
125   }
126   /// TaggedValue constructor for type TENSOR.
127   explicit TaggedValue(TaggedValueTensor tensor)
128       : type_(TENSOR), data_(std::move(tensor)) {}
129   /// TaggedValue constructor for type TENSOR_SPEC.
130   explicit TaggedValue(tensorflow::PartialTensorShape shape,
131                        tensorflow::DataType dtype)
132       : type_(TENSOR_SPEC), data_(shape, dtype) {}
133   /// TaggedValue constructor for type FUNC.
134   explicit TaggedValue(Func f32) : type_(FUNC), data_(f32) {}
135   /// TaggedValue constructor for type FLOAT32.
136   explicit TaggedValue(float f32) : type_(FLOAT32), data_(Float32(f32)) {}
137   /// TaggedValue constructor for type INT64.
138   explicit TaggedValue(int64_t i64) : type_(INT64), data_(Int64(i64)) {}
139   /// TaggedValue constructor for type FLOAT32.
140   explicit TaggedValue(Float32 f32) : type_(FLOAT32), data_(f32) {}
141   /// TaggedValue constructor for type INT64.
142   explicit TaggedValue(Int64 i64) : type_(INT64), data_(i64) {}
143   /// TaggedValue constructor for type STRING.
144   explicit TaggedValue(const char* s) : type_(STRING), data_(s) {}
145   /// Constructs a TaggedValue with type NONE.
146   static TaggedValue None() {
147     TaggedValue v;
148     v.type_ = NONE;
149     return v;
150   }
151   /// Constructs a TaggedValue with type LIST.
152   static TaggedValue List() {
153     TaggedValue v;
154     v.type_ = LIST;
155     using T = decltype(v.data_.list);
156     new (&v.data_.list) T(std::make_shared<T::element_type>());
157     return v;
158   }
159   /// Constructs a TaggedValue with type TUPLE.
160   static TaggedValue Tuple() {
161     TaggedValue v;
162     v.type_ = TUPLE;
163     using T = decltype(v.data_.tuple);
164     new (&v.data_.tuple) T(std::make_shared<T::element_type>());
165     return v;
166   }
167   /// Constructs a TaggedValue with type DICT.
168   static TaggedValue Dict() {
169     TaggedValue v;
170     v.type_ = DICT;
171     using T = decltype(v.data_.dict);
172     new (&v.data_.dict) T(std::make_shared<T::element_type>());
173     return v;
174   }
175   /// Constructs a TaggedValue with type TENSOR.
176   static TaggedValue Tensor(tensorflow::AbstractTensorHandle* raw_ptr) {
177     TaggedValue v;
178     v.type_ = TENSOR;
179     using T = decltype(v.data_.tensor);
180     new (&v.data_.tensor) T(raw_ptr, /*add_ref=*/false);
181     return v;
182   }
183 
184   /// Constructs a TaggedValue with type CAPSULE with a default destructor.
185   template <class T>
186   static TaggedValue Capsule(T* data) {
187     return Capsule(static_cast<void*>(data),
188                    [](void* x) { delete static_cast<T*>(x); });
189   }
190   /// Constructs a TaggedValue with type CAPSULE with a custom destructor.
191   static TaggedValue Capsule(void* data, void (*deleter)(void*)) {
192     TaggedValue v;
193     v.type_ = CAPSULE;
194     using T = decltype(v.data_.capsule);
195     new (&v.data_.capsule) T(data, deleter);
196     return v;
197   }
198   /// Destroys TaggedValue. Shared pointers in unions must be explicitly
199   /// deleted.
200   void destroy() {
201     if (type_ != NONE) {
202       // Explicitly run the destructor on the correct type.
203       visit<void>([](auto& x) {
204         using T = typename std::decay<decltype(x)>::type;
205         x.~T();
206       });
207       // Make the type None, whenever we destroy so we always have an
208       // initialized value.
209       type_ = NONE;
210     }
211   }
212   ~TaggedValue() { destroy(); }
213 
214   /// @brief Get the underlying value based on type.
215   ///
216   /// @tparam T The desired return type.
217   /// @return The unwrapped value. If this `TaggedValue` type does not currently
218   ///         contain a value of type `T`, the program terminates via a call to
219   ///         `assert`.
220   template <typename T>
221   T& get() {
222     assert(type_ == EnumValueOf<T>::value);
223     return UnionAccess<T>::unsafe_reference(*this);
224   }
225 
226   /// @brief Get the underlying value based on type.
227   ///
228   /// @tparam T The desired return type.
229   /// @return The unwrapped value. If this `TaggedValue` type does not currently
230   ///         contain a value of type `T`, the program terminates via a call to
231   ///         `assert`.
232   template <typename T>
233   const T& get() const {
234     assert(type_ == EnumValueOf<T>::value);
235     return UnionAccess<T>::unsafe_reference(*this);
236   }
237 
238   /// Retrieves underlying value from a TaggedValue with type INT64.
239   const Int64& i64() const { return get<impl::Int64>(); }
240 
241   /// Retrieves underlying value from a TaggedValue with type FLOAT32.
242   const Float32& f32() const { return get<impl::Float32>(); }
243 
244   /// Retrieves underlying value from a TaggedValue with type STRING.
245   const char* s() const { return get<impl::String>().str().c_str(); }
246 
247   /// Retrieves underlying value from a TaggedValue with type LIST.
248   impl::List& list() { return *get<impl::ListPtr>(); }
249   /// Retrieves underlying value from a TaggedValue with type LIST.
250   const impl::List& list() const { return *get<impl::ListPtr>(); }
251 
252   /// Retrieves underlying value from a TaggedValue with type TUPLE.
253   impl::Tuple& tuple() { return *get<impl::TuplePtr>(); }
254   /// Retrieves underlying value from TaggedValues with type TUPLE.
255   const impl::Tuple& tuple() const { return *get<impl::TuplePtr>(); }
256 
257   /// Retrieves underlying value from a TaggedValue with type DICT.
258   impl::Dict& dict() { return *get<impl::DictPtr>(); }
259   /// Retrieves underlying value from TaggedValues with type DICT.
260   const impl::Dict& dict() const { return *get<impl::DictPtr>(); }
261 
262   /// Retrieves underlying value from a TaggedValue with type FUNC.
263   impl::Func func() const { return get<impl::Func>(); }
264 
265   // TODO(danielellis): make const-only if possible, once the API allows for it
266   /// Retrieves underlying value from a TaggedValue with type TENSOR.
267   TaggedValueTensor& tensor() { return get<TaggedValueTensor>(); }
268   /// Retrieves underlying value from a TaggedValue with type TENSOR.
269   const TaggedValueTensor& tensor() const { return get<TaggedValueTensor>(); }
270 
271   /// Retrieves underlying value from a TaggedValue with type TENSOR_SPEC.
272   const TensorSpec& tensor_spec() const { return get<TensorSpec>(); }
273 
274   /// Retrieves underlying value from a TaggedValue with type CAPSULE.
275   void* capsule() const { return get<impl::Capsule>().get(); }
276 
277   /// Retrieves type of TaggedValue.
278   Type type() const { return type_; }
279 
280   /// @brief Implements equality operator for TaggedValue.
281   bool operator==(const TaggedValue& o) const {
282     if (type_ != o.type_) return false;
283     switch (type_) {
284       case LIST:
285         return data_.list == o.data_.list;
286         break;
287       case TUPLE:
288         return data_.tuple == o.data_.tuple;
289         break;
290       case DICT:
291         return data_.dict == o.data_.dict;
292         break;
293       case FUNC:
294         // TODO(b/187536093):  This is definitely wrong because the exact ptr of
295         // the function pointer is almost always different, because we hold
296         // it by value. Two tagged values that hold the same std::function
297         // will have different std::function ptrs. operator== is not defined
298         // for std::function's so we need a better solution here, or these
299         // are not comparable which seems bad.
300         return &data_.func == &o.data_.func;
301         break;
302       case FLOAT32:
303         return data_.f32 == o.data_.f32;
304         break;
305       case INT64:
306         return data_.i64 == o.data_.i64;
307         break;
308       case STRING:
309         return data_.s == o.data_.s;
310         break;
311       case TENSOR:
312         return data_.tensor == o.data_.tensor;
313       case TENSOR_SPEC:
314         return data_.tensor_spec == o.data_.tensor_spec;
315       case CAPSULE:
316         return data_.capsule.get() == o.data_.capsule.get();
317       case NONE:
318         return true;
319     }
320   }
321 
322   /// @brief Implements visitor pattern for doing type-based dispatch.
323   ///
324   /// @tparam R The desired return type.
325   /// @tparam Visitor The visitor class which has a callable operator.
326   /// @return The `visitor` called on the correct value.
327   template <class R, class Visitor>
328   R visit(Visitor visitor) {
329     switch (type_) {
330       case LIST:
331         return visitor(data_.list);
332       case TUPLE:
333         return visitor(data_.tuple);
334       case DICT:
335         return visitor(data_.dict);
336       case FUNC:
337         return visitor(data_.func);
338       case FLOAT32:
339         return visitor(data_.f32);
340       case INT64:
341         return visitor(data_.i64);
342       case STRING:
343         return visitor(data_.s);
344       case TENSOR:
345         return visitor(data_.tensor);
346       case TENSOR_SPEC:
347         return visitor(data_.tensor_spec);
348       case CAPSULE:
349         return visitor(data_.capsule);
350       case NONE:
351         return visitor(impl::None::GetInstance());
352     }
353   }
354 
355   /// @brief Implements visitor pattern for doing type-based dispatch.
356   ///
357   /// @tparam R The desired return type.
358   /// @tparam Visitor The visitor class which has a callable operator.
359   /// @return The `visitor` called on the correct value.
360   template <class R, class Visitor>
361   R visit(Visitor visitor) const {
362     switch (type_) {
363       case LIST:
364         return visitor(data_.list);
365       case TUPLE:
366         return visitor(data_.tuple);
367       case DICT:
368         return visitor(data_.dict);
369       case FUNC:
370         return visitor(data_.func);
371       case FLOAT32:
372         return visitor(data_.f32);
373       case INT64:
374         return visitor(data_.i64);
375       case STRING:
376         return visitor(data_.s);
377       case TENSOR:
378         return visitor(data_.tensor);
379       case TENSOR_SPEC:
380         return visitor(data_.tensor_spec);
381       case CAPSULE:
382         return visitor(data_.capsule);
383       case NONE:
384         return visitor(impl::None::GetInstance());
385     }
386   }
387 
388  private:
389   /// @brief A utility class for mapping C++ types to Type values.
390   template <typename T>
391   struct EnumValueOf;
392 
393   /// @brief A utility class for accessing the `Data` union members.
394   template <typename T>
395   struct UnionAccess;
396 
397   // Unsafe Move, because it assumes the union has already been destroyed
398   // or is new!
399   void MoveIntoUnion(TaggedValue&& v) {
400     assert(type_ == NONE);
401     type_ = v.type_;
402     if (type_ != NONE) {
403       visit<void>([&v](auto& left) -> void {
404         using T = typename std::decay<decltype(left)>::type;
405         new (&left) T(std::move(UnionAccess<T>::unsafe_reference(v)));
406       });
407     }
408     // Destroy the source r-value reference (making it None)
409     v.destroy();
410   }
411 
412   // Unsafe Move, because it assumes the union has already been destroyed
413   // or is new!
414   void CopyIntoUnion(const TaggedValue& v) {
415     assert(type_ == NONE);
416     type_ = v.type_;
417     if (type_ != NONE) {
418       visit<void>([&v](auto& left) -> void {
419         using T = typename std::decay<decltype(left)>::type;
420         new (&left) T(UnionAccess<T>::unsafe_reference(v));
421       });
422     }
423   }
424 
425   /// @brief The type of the TaggedValue, i.e. the "tag" of a tagged union.
426   ///
427   /// In principle this could be incorporated into the union
428   /// for pointer types and non-64bit values, but then int64 and float64 values
429   /// would need to be indirected.  This means that we are aiming for a total
430   /// data type size of <=16 bytes, comprised of one pointer (8 bytes) and
431   /// one type (<=8bytes).
432   Type type_;
433 
434   // we use an explicit union here because we want to avoid C++17's
435   // variant structures due to c++14 compatibility requirements.
436   // TODO(b/183980966): Compare against absl::variant.
437   union Data {
438     explicit Data() {}
439     explicit Data(Float32 f32) : f32(f32) {}
440     explicit Data(Int64 i64) : i64(i64) {}
441     explicit Data(const char* s) : s(String(s)) {}
442     explicit Data(Func fn) : func(fn) {}
443     explicit Data(TaggedValueTensor tensor_in) {
444       new (&tensor) TaggedValueTensor(std::move(tensor_in));
445     }
446     explicit Data(tensorflow::PartialTensorShape shape,
447                   tensorflow::DataType dtype)
448         : tensor_spec({shape, dtype}) {}
449     ~Data() {}
450     Float32 f32;
451     Int64 i64;
452     String s;
453     Func func;
454     // TODO(aselle): look at tensorflow thing
455     std::shared_ptr<impl::Dict> dict;
456     std::shared_ptr<impl::List> list;
457     std::shared_ptr<impl::Tuple> tuple;
458     impl::Capsule capsule;
459     TaggedValueTensor tensor;
460     TensorSpec tensor_spec;
461   } data_;
462   friend std::ostream& operator<<(std::ostream& o, const TaggedValue& v);
463   friend TaggedValueHash<TaggedValue>;
464 };
465 
466 #define TF_ENUM_VALUE_OF(TYPE, ENUM)      \
467   template <>                             \
468   struct TaggedValue::EnumValueOf<TYPE> { \
469     static constexpr Type value = ENUM;   \
470   };
471 
472 TF_ENUM_VALUE_OF(impl::Capsule, CAPSULE);
473 TF_ENUM_VALUE_OF(impl::Float32, FLOAT32);
474 TF_ENUM_VALUE_OF(impl::Int64, INT64);
475 TF_ENUM_VALUE_OF(impl::List, LIST);
476 TF_ENUM_VALUE_OF(impl::ListPtr, LIST);
477 TF_ENUM_VALUE_OF(impl::Tuple, TUPLE);
478 TF_ENUM_VALUE_OF(impl::TuplePtr, TUPLE);
479 TF_ENUM_VALUE_OF(impl::Dict, DICT);
480 TF_ENUM_VALUE_OF(impl::DictPtr, DICT);
481 TF_ENUM_VALUE_OF(impl::None, NONE);
482 TF_ENUM_VALUE_OF(impl::Func, FUNC);
483 TF_ENUM_VALUE_OF(impl::String, STRING);
484 TF_ENUM_VALUE_OF(impl::TaggedValueTensor, TENSOR);
485 TF_ENUM_VALUE_OF(impl::TensorSpec, TENSOR_SPEC);
486 #undef TF_ENUM_VALUE_OF
487 
488 #define TF_UNION_ACCESS_INSTANCE(TYPE, MEMBER)                               \
489   template <>                                                                \
490   struct TaggedValue::UnionAccess<TYPE> {                                    \
491     static TYPE& unsafe_reference(TaggedValue& t) { return t.data_.MEMBER; } \
492     static const TYPE& unsafe_reference(const TaggedValue& t) {              \
493       return t.data_.MEMBER;                                                 \
494     }                                                                        \
495   };
496 
497 TF_UNION_ACCESS_INSTANCE(impl::Capsule, capsule);
498 TF_UNION_ACCESS_INSTANCE(impl::Float32, f32);
499 TF_UNION_ACCESS_INSTANCE(impl::Int64, i64);
500 TF_UNION_ACCESS_INSTANCE(impl::ListPtr, list);
501 TF_UNION_ACCESS_INSTANCE(impl::TuplePtr, tuple);
502 TF_UNION_ACCESS_INSTANCE(impl::DictPtr, dict);
503 TF_UNION_ACCESS_INSTANCE(impl::Func, func);
504 TF_UNION_ACCESS_INSTANCE(impl::String, s);
505 TF_UNION_ACCESS_INSTANCE(impl::TaggedValueTensor, tensor);
506 TF_UNION_ACCESS_INSTANCE(impl::TensorSpec, tensor_spec);
507 #undef TF_UNION_ACCESS_INSTANCE
508 
509 /// The union accessor for `NoneType`.
510 template <>
511 struct TaggedValue::UnionAccess<impl::None> {
512   static impl::None& unsafe_reference(TaggedValue& t) {
513     return None::GetInstance();
514   }
515   static const impl::None& unsafe_reference(const TaggedValue& t) {
516     return None::GetInstance();
517   }
518 };
519 
520 /// @brief The Tuple class for holding tuples of TaggedValues.
521 /// TODO: Need to wrap vector in Tuple otherwise variant has duplicate types.
522 class Tuple {
523   using TU = std::vector<TaggedValue>;
524   using value_type = TU::value_type;
525   using iterator = TU::iterator;
526   using const_iterator = TU::const_iterator;
527   TU values_;
528 
529  public:
530   TU::iterator begin() { return values_.begin(); }
531   TU::iterator end() { return values_.end(); }
532   TU::const_iterator begin() const { return values_.begin(); }
533   TU::const_iterator end() const { return values_.end(); }
534   const TU::value_type& operator[](size_t i) const { return values_[i]; }
535   TU::value_type& operator[](size_t i) { return values_[i]; }
536   size_t size() const { return values_.size(); }
537   void emplace_back(TaggedValue v) { values_.emplace_back(std::move(v)); }
538   void push_back(const TaggedValue& v) { values_.push_back(v); }
539 };
540 
541 /// Hashing infrastructure for Tuple.
542 inline size_t TaggedValueHash<Tuple>::operator()(const Tuple& t) const {
543   std::size_t hash = 0;
544   for (auto& i : t) {
545     hash ^= TaggedValueHash<TaggedValue>()(i);
546   }
547   return hash;
548 }
549 
550 /// @brief The TaggedValueHashVisitor class for doing type-based hashing
551 /// of TaggedValues.
552 class TaggedValueHashVisitor {
553  public:
554   size_t operator()(const TaggedValueTensor& v) {
555     assert(false);
556     return 0;
557   }
558   size_t operator()(const ListPtr& v) {
559     assert(false);
560     return 0;
561   }
562   size_t operator()(const DictPtr& v) {
563     assert(false);
564     return 0;
565   }
566   size_t operator()(const Capsule& t) { return std::hash<Capsule>()(t); }
567   size_t operator()(const Func& t) {
568     assert(false);
569     return 0;
570   }
571   size_t operator()(const TuplePtr& t) {
572     std::size_t hash = 0;
573     for (auto it = t->begin(); it != t->end(); ++it) {
574       hash ^= TaggedValueHash<TaggedValue>()(*it);
575     }
576     return hash;
577   }
578   template <class T>
579   size_t operator()(const T& t) {
580     return absl::Hash<T>()(t);
581   }
582 };
583 
584 /// Hashing infrastructure for TaggedValues. Hashable TaggedValues overload
585 /// `AbslHashValue`. Non-hashable structures return 0, since we have no easy
586 /// way to abort.
587 inline size_t TaggedValueHash<TaggedValue>::operator()(
588     const TaggedValue& v) const {
589   return v.visit<size_t>(TaggedValueHashVisitor());
590 }
591 
592 }  // namespace impl
593 }  // namespace libtf
594 }  // namespace tf
595 
596 #endif  // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_
597