/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// @file value.h
/// @brief The TaggedValue struct that supports Python-like behavior in C++.
///
/// The TaggedValue struct implements a tagged union data structure
/// (https://en.wikipedia.org/wiki/Tagged_union) in the TensorFlow C++ API. It
/// contains a `Type` enum (sometimes referred to as a "tag")
/// and a `Data` union for holding values.

#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_
#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/platform/intrusive_ptr.h"
#include "tensorflow/core/platform/statusor.h"

// TODO(b/195578409): Move all value objects into `impl`. Currently only values
// that do not reference TaggedValue are there.
#include "tensorflow/cc/experimental/libtf/impl/none.h"
#include "tensorflow/cc/experimental/libtf/impl/scalars.h"
#include "tensorflow/cc/experimental/libtf/impl/string.h"
#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h"

namespace tf {
namespace libtf {
namespace impl {
// Necessary forward declares.
47 class TaggedValue; 48 class Tuple; 49 template <class T> 50 // TODO(ccrusius): Use absl::Hash specializations instead. 51 class TaggedValueHash; 52 using List = std::vector<TaggedValue>; 53 using ListPtr = std::shared_ptr<List>; 54 using Dict = 55 absl::flat_hash_map<TaggedValue, TaggedValue, TaggedValueHash<TaggedValue>>; 56 using DictPtr = std::shared_ptr<Dict>; 57 using TuplePtr = std::shared_ptr<Tuple>; 58 using Func = 59 std::function<tensorflow::StatusOr<TaggedValue>(TaggedValue, TaggedValue)>; 60 // A capsule holds a pointer and a destructor for the pointer (i.e. a generic 61 // shared_ptr to void with a custom deleter). 62 using Capsule = std::shared_ptr<void>; 63 using TaggedValueTensor = 64 tensorflow::core::IntrusivePtr<tensorflow::AbstractTensorHandle>; 65 66 // Declare hash types so they can be instantiated below. 67 68 /// @brief TaggedValue hashing infrastructure, which uses absl::hash. 69 /// 70 /// Hashable TaggedValues overload `AbslHashValue`. Non-hashable structures 71 /// return 0. 72 template <> 73 struct TaggedValueHash<TaggedValue> { 74 size_t operator()(const TaggedValue& v) const; 75 }; 76 77 /// @brief Hash implementation for TaggedValue Tuples. 78 template <> 79 struct TaggedValueHash<Tuple> { 80 size_t operator()(const Tuple& t) const; 81 }; 82 83 /// @brief The basic `TaggedValue` tagged union type. 84 /// 85 /// A `TaggedValue` contains a `Type` (or "tag") as an enum and a `Value` union. 86 /// Values include tensors, primitive values, lists, tuples, and dictionaries. 87 /// In the future we might also want to have representation of python objects in 88 /// the form of PyObject*. 89 class TaggedValue final { 90 public: 91 /// @brief Enum that describes the possible types a `TaggedValue` can be. 92 /// 93 /// A `TaggedValue` must be one of the following types: NONE, INT64, FLOAT32, 94 /// STRING, FUNC, DICT, LIST, TUPLE, TENSOR, TENSOR_SPEC, CAPSULE. 
95 enum Type { 96 NONE = 0, 97 INT64 = 1, 98 FLOAT32 = 2, 99 STRING = 3, 100 FUNC = 4, 101 DICT = 5, 102 LIST = 6, 103 TUPLE = 7, 104 TENSOR = 8, 105 TENSOR_SPEC = 9, 106 CAPSULE = 10, 107 }; 108 TaggedValue() : type_(NONE), data_() {} 109 110 /// Move assignment operator. 111 TaggedValue& operator=(TaggedValue&& v) { 112 destroy(); 113 MoveIntoUnion(std::move(v)); 114 return *this; 115 } 116 /// Move constructor. 117 TaggedValue(TaggedValue&& v) : type_(NONE) { MoveIntoUnion(std::move(v)); } 118 /// Copy constructor. 119 TaggedValue(const TaggedValue& v) : type_(NONE) { CopyIntoUnion(v); } 120 /// Copy assignment operator. 121 TaggedValue& operator=(const TaggedValue& v) { 122 destroy(); 123 CopyIntoUnion(v); 124 return *this; 125 } 126 /// TaggedValue constructor for type TENSOR. 127 explicit TaggedValue(TaggedValueTensor tensor) 128 : type_(TENSOR), data_(std::move(tensor)) {} 129 /// TaggedValue constructor for type TENSOR_SPEC. 130 explicit TaggedValue(tensorflow::PartialTensorShape shape, 131 tensorflow::DataType dtype) 132 : type_(TENSOR_SPEC), data_(shape, dtype) {} 133 /// TaggedValue constructor for type FUNC. 134 explicit TaggedValue(Func f32) : type_(FUNC), data_(f32) {} 135 /// TaggedValue constructor for type FLOAT32. 136 explicit TaggedValue(float f32) : type_(FLOAT32), data_(Float32(f32)) {} 137 /// TaggedValue constructor for type INT64. 138 explicit TaggedValue(int64_t i64) : type_(INT64), data_(Int64(i64)) {} 139 /// TaggedValue constructor for type FLOAT32. 140 explicit TaggedValue(Float32 f32) : type_(FLOAT32), data_(f32) {} 141 /// TaggedValue constructor for type INT64. 142 explicit TaggedValue(Int64 i64) : type_(INT64), data_(i64) {} 143 /// TaggedValue constructor for type STRING. 144 explicit TaggedValue(const char* s) : type_(STRING), data_(s) {} 145 /// Constructs a TaggedValue with type NONE. 146 static TaggedValue None() { 147 TaggedValue v; 148 v.type_ = NONE; 149 return v; 150 } 151 /// Constructs a TaggedValue with type LIST. 
152 static TaggedValue List() { 153 TaggedValue v; 154 v.type_ = LIST; 155 using T = decltype(v.data_.list); 156 new (&v.data_.list) T(std::make_shared<T::element_type>()); 157 return v; 158 } 159 /// Constructs a TaggedValue with type TUPLE. 160 static TaggedValue Tuple() { 161 TaggedValue v; 162 v.type_ = TUPLE; 163 using T = decltype(v.data_.tuple); 164 new (&v.data_.tuple) T(std::make_shared<T::element_type>()); 165 return v; 166 } 167 /// Constructs a TaggedValue with type DICT. 168 static TaggedValue Dict() { 169 TaggedValue v; 170 v.type_ = DICT; 171 using T = decltype(v.data_.dict); 172 new (&v.data_.dict) T(std::make_shared<T::element_type>()); 173 return v; 174 } 175 /// Constructs a TaggedValue with type TENSOR. 176 static TaggedValue Tensor(tensorflow::AbstractTensorHandle* raw_ptr) { 177 TaggedValue v; 178 v.type_ = TENSOR; 179 using T = decltype(v.data_.tensor); 180 new (&v.data_.tensor) T(raw_ptr, /*add_ref=*/false); 181 return v; 182 } 183 184 /// Constructs a TaggedValue with type CAPSULE with a default destructor. 185 template <class T> 186 static TaggedValue Capsule(T* data) { 187 return Capsule(static_cast<void*>(data), 188 [](void* x) { delete static_cast<T*>(x); }); 189 } 190 /// Constructs a TaggedValue with type CAPSULE with a custom destructor. 191 static TaggedValue Capsule(void* data, void (*deleter)(void*)) { 192 TaggedValue v; 193 v.type_ = CAPSULE; 194 using T = decltype(v.data_.capsule); 195 new (&v.data_.capsule) T(data, deleter); 196 return v; 197 } 198 /// Destroys TaggedValue. Shared pointers in unions must be explicitly 199 /// deleted. 200 void destroy() { 201 if (type_ != NONE) { 202 // Explicitly run the destructor on the correct type. 203 visit<void>([](auto& x) { 204 using T = typename std::decay<decltype(x)>::type; 205 x.~T(); 206 }); 207 // Make the type None, whenever we destroy so we always have an 208 // initialized value. 
209 type_ = NONE; 210 } 211 } 212 ~TaggedValue() { destroy(); } 213 214 /// @brief Get the underlying value based on type. 215 /// 216 /// @tparam T The desired return type. 217 /// @return The unwrapped value. If this `TaggedValue` type does not currently 218 /// contain a value of type `T`, the program terminates via a call to 219 /// `assert`. 220 template <typename T> 221 T& get() { 222 assert(type_ == EnumValueOf<T>::value); 223 return UnionAccess<T>::unsafe_reference(*this); 224 } 225 226 /// @brief Get the underlying value based on type. 227 /// 228 /// @tparam T The desired return type. 229 /// @return The unwrapped value. If this `TaggedValue` type does not currently 230 /// contain a value of type `T`, the program terminates via a call to 231 /// `assert`. 232 template <typename T> 233 const T& get() const { 234 assert(type_ == EnumValueOf<T>::value); 235 return UnionAccess<T>::unsafe_reference(*this); 236 } 237 238 /// Retrieves underlying value from a TaggedValue with type INT64. 239 const Int64& i64() const { return get<impl::Int64>(); } 240 241 /// Retrieves underlying value from a TaggedValue with type FLOAT32. 242 const Float32& f32() const { return get<impl::Float32>(); } 243 244 /// Retrieves underlying value from a TaggedValue with type STRING. 245 const char* s() const { return get<impl::String>().str().c_str(); } 246 247 /// Retrieves underlying value from a TaggedValue with type LIST. 248 impl::List& list() { return *get<impl::ListPtr>(); } 249 /// Retrieves underlying value from a TaggedValue with type LIST. 250 const impl::List& list() const { return *get<impl::ListPtr>(); } 251 252 /// Retrieves underlying value from a TaggedValue with type TUPLE. 253 impl::Tuple& tuple() { return *get<impl::TuplePtr>(); } 254 /// Retrieves underlying value from TaggedValues with type TUPLE. 255 const impl::Tuple& tuple() const { return *get<impl::TuplePtr>(); } 256 257 /// Retrieves underlying value from a TaggedValue with type DICT. 
258 impl::Dict& dict() { return *get<impl::DictPtr>(); } 259 /// Retrieves underlying value from TaggedValues with type DICT. 260 const impl::Dict& dict() const { return *get<impl::DictPtr>(); } 261 262 /// Retrieves underlying value from a TaggedValue with type FUNC. 263 impl::Func func() const { return get<impl::Func>(); } 264 265 // TODO(danielellis): make const-only if possible, once the API allows for it 266 /// Retrieves underlying value from a TaggedValue with type TENSOR. 267 TaggedValueTensor& tensor() { return get<TaggedValueTensor>(); } 268 /// Retrieves underlying value from a TaggedValue with type TENSOR. 269 const TaggedValueTensor& tensor() const { return get<TaggedValueTensor>(); } 270 271 /// Retrieves underlying value from a TaggedValue with type TENSOR_SPEC. 272 const TensorSpec& tensor_spec() const { return get<TensorSpec>(); } 273 274 /// Retrieves underlying value from a TaggedValue with type CAPSULE. 275 void* capsule() const { return get<impl::Capsule>().get(); } 276 277 /// Retrieves type of TaggedValue. 278 Type type() const { return type_; } 279 280 /// @brief Implements equality operator for TaggedValue. 281 bool operator==(const TaggedValue& o) const { 282 if (type_ != o.type_) return false; 283 switch (type_) { 284 case LIST: 285 return data_.list == o.data_.list; 286 break; 287 case TUPLE: 288 return data_.tuple == o.data_.tuple; 289 break; 290 case DICT: 291 return data_.dict == o.data_.dict; 292 break; 293 case FUNC: 294 // TODO(b/187536093): This is definitely wrong because the exact ptr of 295 // the function pointer is almost always different, because we hold 296 // it by value. Two tagged values that hold the same std::function 297 // will have different std::function ptrs. operator== is not defined 298 // for std::function's so we need a better solution here, or these 299 // are not comparable which seems bad. 
300 return &data_.func == &o.data_.func; 301 break; 302 case FLOAT32: 303 return data_.f32 == o.data_.f32; 304 break; 305 case INT64: 306 return data_.i64 == o.data_.i64; 307 break; 308 case STRING: 309 return data_.s == o.data_.s; 310 break; 311 case TENSOR: 312 return data_.tensor == o.data_.tensor; 313 case TENSOR_SPEC: 314 return data_.tensor_spec == o.data_.tensor_spec; 315 case CAPSULE: 316 return data_.capsule.get() == o.data_.capsule.get(); 317 case NONE: 318 return true; 319 } 320 } 321 322 /// @brief Implements visitor pattern for doing type-based dispatch. 323 /// 324 /// @tparam R The desired return type. 325 /// @tparam Visitor The visitor class which has a callable operator. 326 /// @return The `visitor` called on the correct value. 327 template <class R, class Visitor> 328 R visit(Visitor visitor) { 329 switch (type_) { 330 case LIST: 331 return visitor(data_.list); 332 case TUPLE: 333 return visitor(data_.tuple); 334 case DICT: 335 return visitor(data_.dict); 336 case FUNC: 337 return visitor(data_.func); 338 case FLOAT32: 339 return visitor(data_.f32); 340 case INT64: 341 return visitor(data_.i64); 342 case STRING: 343 return visitor(data_.s); 344 case TENSOR: 345 return visitor(data_.tensor); 346 case TENSOR_SPEC: 347 return visitor(data_.tensor_spec); 348 case CAPSULE: 349 return visitor(data_.capsule); 350 case NONE: 351 return visitor(impl::None::GetInstance()); 352 } 353 } 354 355 /// @brief Implements visitor pattern for doing type-based dispatch. 356 /// 357 /// @tparam R The desired return type. 358 /// @tparam Visitor The visitor class which has a callable operator. 359 /// @return The `visitor` called on the correct value. 
360 template <class R, class Visitor> 361 R visit(Visitor visitor) const { 362 switch (type_) { 363 case LIST: 364 return visitor(data_.list); 365 case TUPLE: 366 return visitor(data_.tuple); 367 case DICT: 368 return visitor(data_.dict); 369 case FUNC: 370 return visitor(data_.func); 371 case FLOAT32: 372 return visitor(data_.f32); 373 case INT64: 374 return visitor(data_.i64); 375 case STRING: 376 return visitor(data_.s); 377 case TENSOR: 378 return visitor(data_.tensor); 379 case TENSOR_SPEC: 380 return visitor(data_.tensor_spec); 381 case CAPSULE: 382 return visitor(data_.capsule); 383 case NONE: 384 return visitor(impl::None::GetInstance()); 385 } 386 } 387 388 private: 389 /// @brief A utility class for mapping C++ types to Type values. 390 template <typename T> 391 struct EnumValueOf; 392 393 /// @brief A utility class for accessing the `Data` union members. 394 template <typename T> 395 struct UnionAccess; 396 397 // Unsafe Move, because it assumes the union has already been destroyed 398 // or is new! 399 void MoveIntoUnion(TaggedValue&& v) { 400 assert(type_ == NONE); 401 type_ = v.type_; 402 if (type_ != NONE) { 403 visit<void>([&v](auto& left) -> void { 404 using T = typename std::decay<decltype(left)>::type; 405 new (&left) T(std::move(UnionAccess<T>::unsafe_reference(v))); 406 }); 407 } 408 // Destroy the source r-value reference (making it None) 409 v.destroy(); 410 } 411 412 // Unsafe Move, because it assumes the union has already been destroyed 413 // or is new! 414 void CopyIntoUnion(const TaggedValue& v) { 415 assert(type_ == NONE); 416 type_ = v.type_; 417 if (type_ != NONE) { 418 visit<void>([&v](auto& left) -> void { 419 using T = typename std::decay<decltype(left)>::type; 420 new (&left) T(UnionAccess<T>::unsafe_reference(v)); 421 }); 422 } 423 } 424 425 /// @brief The type of the TaggedValue, i.e. the "tag" of a tagged union. 
426 /// 427 /// In principle this could be incorporated into the union 428 /// for pointer types and non-64bit values, but then int64 and float64 values 429 /// would need to be indirected. This means that we are aiming for a total 430 /// data type size of <=16 bytes, comprised of one pointer (8 bytes) and 431 /// one type (<=8bytes). 432 Type type_; 433 434 // we use an explicit union here because we want to avoid C++17's 435 // variant structures due to c++14 compatibility requirements. 436 // TODO(b/183980966): Compare against absl::variant. 437 union Data { 438 explicit Data() {} 439 explicit Data(Float32 f32) : f32(f32) {} 440 explicit Data(Int64 i64) : i64(i64) {} 441 explicit Data(const char* s) : s(String(s)) {} 442 explicit Data(Func fn) : func(fn) {} 443 explicit Data(TaggedValueTensor tensor_in) { 444 new (&tensor) TaggedValueTensor(std::move(tensor_in)); 445 } 446 explicit Data(tensorflow::PartialTensorShape shape, 447 tensorflow::DataType dtype) 448 : tensor_spec({shape, dtype}) {} 449 ~Data() {} 450 Float32 f32; 451 Int64 i64; 452 String s; 453 Func func; 454 // TODO(aselle): look at tensorflow thing 455 std::shared_ptr<impl::Dict> dict; 456 std::shared_ptr<impl::List> list; 457 std::shared_ptr<impl::Tuple> tuple; 458 impl::Capsule capsule; 459 TaggedValueTensor tensor; 460 TensorSpec tensor_spec; 461 } data_; 462 friend std::ostream& operator<<(std::ostream& o, const TaggedValue& v); 463 friend TaggedValueHash<TaggedValue>; 464 }; 465 466 #define TF_ENUM_VALUE_OF(TYPE, ENUM) \ 467 template <> \ 468 struct TaggedValue::EnumValueOf<TYPE> { \ 469 static constexpr Type value = ENUM; \ 470 }; 471 472 TF_ENUM_VALUE_OF(impl::Capsule, CAPSULE); 473 TF_ENUM_VALUE_OF(impl::Float32, FLOAT32); 474 TF_ENUM_VALUE_OF(impl::Int64, INT64); 475 TF_ENUM_VALUE_OF(impl::List, LIST); 476 TF_ENUM_VALUE_OF(impl::ListPtr, LIST); 477 TF_ENUM_VALUE_OF(impl::Tuple, TUPLE); 478 TF_ENUM_VALUE_OF(impl::TuplePtr, TUPLE); 479 TF_ENUM_VALUE_OF(impl::Dict, DICT); 480 
TF_ENUM_VALUE_OF(impl::DictPtr, DICT); 481 TF_ENUM_VALUE_OF(impl::None, NONE); 482 TF_ENUM_VALUE_OF(impl::Func, FUNC); 483 TF_ENUM_VALUE_OF(impl::String, STRING); 484 TF_ENUM_VALUE_OF(impl::TaggedValueTensor, TENSOR); 485 TF_ENUM_VALUE_OF(impl::TensorSpec, TENSOR_SPEC); 486 #undef TF_ENUM_VALUE_OF 487 488 #define TF_UNION_ACCESS_INSTANCE(TYPE, MEMBER) \ 489 template <> \ 490 struct TaggedValue::UnionAccess<TYPE> { \ 491 static TYPE& unsafe_reference(TaggedValue& t) { return t.data_.MEMBER; } \ 492 static const TYPE& unsafe_reference(const TaggedValue& t) { \ 493 return t.data_.MEMBER; \ 494 } \ 495 }; 496 497 TF_UNION_ACCESS_INSTANCE(impl::Capsule, capsule); 498 TF_UNION_ACCESS_INSTANCE(impl::Float32, f32); 499 TF_UNION_ACCESS_INSTANCE(impl::Int64, i64); 500 TF_UNION_ACCESS_INSTANCE(impl::ListPtr, list); 501 TF_UNION_ACCESS_INSTANCE(impl::TuplePtr, tuple); 502 TF_UNION_ACCESS_INSTANCE(impl::DictPtr, dict); 503 TF_UNION_ACCESS_INSTANCE(impl::Func, func); 504 TF_UNION_ACCESS_INSTANCE(impl::String, s); 505 TF_UNION_ACCESS_INSTANCE(impl::TaggedValueTensor, tensor); 506 TF_UNION_ACCESS_INSTANCE(impl::TensorSpec, tensor_spec); 507 #undef TF_UNION_ACCESS_INSTANCE 508 509 /// The union accessor for `NoneType`. 510 template <> 511 struct TaggedValue::UnionAccess<impl::None> { 512 static impl::None& unsafe_reference(TaggedValue& t) { 513 return None::GetInstance(); 514 } 515 static const impl::None& unsafe_reference(const TaggedValue& t) { 516 return None::GetInstance(); 517 } 518 }; 519 520 /// @brief The Tuple class for holding tuples of TaggedValues. 521 /// TODO: Need to wrap vector in Tuple otherwise variant has duplicate types. 
522 class Tuple { 523 using TU = std::vector<TaggedValue>; 524 using value_type = TU::value_type; 525 using iterator = TU::iterator; 526 using const_iterator = TU::const_iterator; 527 TU values_; 528 529 public: 530 TU::iterator begin() { return values_.begin(); } 531 TU::iterator end() { return values_.end(); } 532 TU::const_iterator begin() const { return values_.begin(); } 533 TU::const_iterator end() const { return values_.end(); } 534 const TU::value_type& operator[](size_t i) const { return values_[i]; } 535 TU::value_type& operator[](size_t i) { return values_[i]; } 536 size_t size() const { return values_.size(); } 537 void emplace_back(TaggedValue v) { values_.emplace_back(std::move(v)); } 538 void push_back(const TaggedValue& v) { values_.push_back(v); } 539 }; 540 541 /// Hashing infrastructure for Tuple. 542 inline size_t TaggedValueHash<Tuple>::operator()(const Tuple& t) const { 543 std::size_t hash = 0; 544 for (auto& i : t) { 545 hash ^= TaggedValueHash<TaggedValue>()(i); 546 } 547 return hash; 548 } 549 550 /// @brief The TaggedValueHashVisitor class for doing type-based hashing 551 /// of TaggedValues. 552 class TaggedValueHashVisitor { 553 public: 554 size_t operator()(const TaggedValueTensor& v) { 555 assert(false); 556 return 0; 557 } 558 size_t operator()(const ListPtr& v) { 559 assert(false); 560 return 0; 561 } 562 size_t operator()(const DictPtr& v) { 563 assert(false); 564 return 0; 565 } 566 size_t operator()(const Capsule& t) { return std::hash<Capsule>()(t); } 567 size_t operator()(const Func& t) { 568 assert(false); 569 return 0; 570 } 571 size_t operator()(const TuplePtr& t) { 572 std::size_t hash = 0; 573 for (auto it = t->begin(); it != t->end(); ++it) { 574 hash ^= TaggedValueHash<TaggedValue>()(*it); 575 } 576 return hash; 577 } 578 template <class T> 579 size_t operator()(const T& t) { 580 return absl::Hash<T>()(t); 581 } 582 }; 583 584 /// Hashing infrastructure for TaggedValues. 
Hashable TaggedValues overload 585 /// `AbslHashValue`. Non-hashable structures return 0, since we have no easy 586 /// way to abort. 587 inline size_t TaggedValueHash<TaggedValue>::operator()( 588 const TaggedValue& v) const { 589 return v.visit<size_t>(TaggedValueHashVisitor()); 590 } 591 592 } // namespace impl 593 } // namespace libtf 594 } // namespace tf 595 596 #endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_ 597