xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/experimental/libtf/object.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 /// @file object.h
16 /// @brief Object hierarchy for the TensorFlow C++ API. All "objects" are
17 /// derived from the `Handle` class. Instances of `Handle` are referred to as
18 /// "handles". All handles have a tagged value.
19 ///
20 /// Example Usage:
21 /// Object runtime = GetRuntime("tfrt");
22 /// Object module = runtime.Get("Import")("cool_mobilenet")
23 /// runtime.Get("Tensor")(Tuple(5,5,5), 3.3);
24 /// Object test = CreateModule("test");
25 /// test.Set("cool_function", callable);
26 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_
27 #define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_
28 
29 #include <string>
30 #include <utility>
31 
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/cc/experimental/libtf/value.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/status.h"
36 #include "tensorflow/core/platform/statusor.h"
37 
38 namespace tf {
39 namespace libtf {
40 
41 using TaggedValue = impl::TaggedValue;
42 class Handle;
43 
44 // Necessary forward declare.
45 template <class T>
46 Handle Convert(T value);
47 
48 /// @brief Base Handle class that wraps TaggedValue data. All data creation and
49 /// manipulation should done using Handle instances. Users should not be working
50 /// with TaggedValues directly.
51 
52 /// The `Handle` class contains a TaggedValue in the `value_` member, which
53 /// contains the underlying data. An object belonging to `Foo`, a derived class
54 /// of `Handle`, can be referred to as a `Foo` handle.
55 ///
56 /// It is important that all derived classes do not add any new data fields.
57 /// This ensures that it is always safe to slice down (i.e. assign an object of
58 /// a derived class to the base class) a handle to the base Handle class.
59 class Handle {
60  public:
61   /// Default constructor, which initializes a TaggedValue with type NONE.
Handle()62   Handle() : value_(TaggedValue::None()) {}
63 
64  public:
65   /// Constructs a handle from a TaggedValue.
Handle(TaggedValue value)66   explicit Handle(TaggedValue value) : value_(std::move(value)) {}
67   // explicit Handle(TaggedValue value, Handle* class_input)
68   //     : value_(std::move(value)), class_(class_input) {}
69   // const Handle& type() { return *class_; }
70 
71  protected:
72   /// The wrapped TaggedValue.
73   TaggedValue value_;
74   // effectively a "weak reference" to intern'd class value.
75   // types are compared by comparing pointer values here.
76   // Handle* class_;  // effectively a "weak reference" to intern'd class value.
77 
78   /// The Integer handle.
79   friend class Integer;
80   /// The Float handle.
81   friend class Float;
82   /// The String handle.
83   friend class String;
84   /// The Object handle.
85   friend class Object;
86   /// The List handle.
87   friend class List;
88   /// The Dictionary handle.
89   friend class Dictionary;
90   /// The Tuple handle.
91   friend class Tuple;
92   /// The Callable handle.
93   friend class Callable;
94   /// The Tensor handle.
95   friend class Tensor;
96   /// Converts a Handle instance to an instance of a derived class `T`.
97   template <class T>
98   friend tensorflow::StatusOr<T> Cast(Handle handle);
99   /// Infrastructure for converting a TaggedValue tuple function signature to an
100   /// unpacked variable list.
101   template <typename Fn, class TRET, class... ArgsOut>
102   friend class UneraseCallHelper;
103 };
104 
105 // Forward declare.
106 template <class T>
107 tensorflow::StatusOr<T> Cast(Handle handle);
108 
109 /// @brief The None class for holding TaggedValues of type NONE.
110 class None final : public Handle {
111  public:
112   /// Creates a handle that wraps a NONE TaggedValue.
None()113   None() : Handle(TaggedValue::None()) {}
114 
115  private:
None(TaggedValue v)116   explicit None(TaggedValue v) : Handle(std::move(v)) {}
117   template <class T>
118   friend tensorflow::StatusOr<T> Cast(Handle handle);
119 };
120 
121 /// @brief The String class for holding TaggedValues of type STRING.
122 class String final : public Handle {
123  public:
124   /// Creates a handle that wraps a STRING TaggedValue.
String(const char * s)125   explicit String(const char* s) : Handle(TaggedValue(s)) {}
126   /// Returns the underlying TaggedValue string.
get()127   const char* get() const { return value_.s(); }
128 
129  private:
130   // Private since it is in general unsafe.
String(TaggedValue v)131   explicit String(TaggedValue v) : Handle(std::move(v)) {}
132   template <class T>
133   friend tensorflow::StatusOr<T> Cast(Handle handle);
134 };
135 
136 /// @brief The `Object` class modeled after Python "objects".
137 ///
138 /// An `Object` uses a TaggedValue dictionary to store its attributes. The
139 /// "__parent__" attribute is reserved.
140 class Object : public Handle {
141  public:
142   /// Constructs a handle that acts as an object.
Object()143   Object() : Handle(TaggedValue::Dict()) {}
144   /// Retrieves the key of the object's parent.
145   static const String& ParentKey();
146 
147   /// @brief Gets an object member attribute`key`.
148   ///
149   /// If the `key` is not found in the object, the object's "__parent__"
150   /// attribute is then searched.
151   ///
152   /// @tparam T The desired return type.
153   /// @param key The key to look up.
154   /// @return `StatusOr` wrapping the key's value.
155   template <class T = Handle>
Get(const String & key)156   tensorflow::StatusOr<T> Get(const String& key) {
157     auto& dict = value_.dict();
158     auto it = dict.find(key.value_);
159     if (it != dict.end()) {
160       return Cast<T>(Handle(it->second));
161     } else {
162       // Lookup in object stored by reference in attribute  "__parent__".
163       auto it_class = dict.find(ParentKey().value_);
164       if (it_class != dict.end()) {
165         auto& class_dict_maybe = it_class->second;
166         if (class_dict_maybe.type() == TaggedValue::DICT) {
167           auto& dict = class_dict_maybe.dict();
168           auto it = dict.find(key.value_);
169           if (it != value_.dict().end()) {
170             return Cast<T>(Handle(it->second));
171           }
172         }
173       }
174     }
175     return tensorflow::errors::NotFound("Key not in dictionary.");
176   }
177 
178   /// Sets `key` attribute with the underlying value of `h`.
Set(const String & key,Handle h)179   void Set(const String& key, Handle h) {
180     value_.dict()[key.value_] = std::move(h.value_);
181   }
182 
183   /// Removes `key` from the object's attributes.
Unset(const String & key)184   void Unset(const String& key) { value_.dict().erase(key.value_); }
185   // TODO(b/): Adding dir() is in the future.
186  private:
187   // Private since it is in general unsafe.
Object(TaggedValue v)188   explicit Object(TaggedValue v) : Handle(std::move(v)) {}
189   template <class T>
190   friend tensorflow::StatusOr<T> Cast(Handle handle);
191 };
192 
193 /// @brief The Dictionary class for holding TaggedValues of type DICT.
194 class Dictionary final : public Handle {
195  public:
196   /// Constructs a handle that wraps a DICT TaggedValue.
Dictionary()197   Dictionary() : Handle(TaggedValue::Dict()) {}
198   // TODO(aselle): make this private to preserve invariant.
199 
200   /// Retrieves `key` with type `T`.
201   template <class T>
Get(const Handle & key)202   tensorflow::StatusOr<T> Get(const Handle& key) {
203     auto it = value_.dict().find(key.value_);
204     if (it != value_.dict().end()) return Cast<T>(Handle(it->second));
205     return tensorflow::errors::NotFound("Key not in dictionary.");
206   }
207   /// Sets `key` with value `value`.
Set(const String & key,Handle value)208   void Set(const String& key, Handle value) {
209     value_.dict()[key.value_] = std::move(value.value_);
210   }
211   /// Sets `key` with value `value`.
Set(const Handle & key,Handle value)212   void Set(const Handle& key, Handle value) {
213     value_.dict()[key.value_] = std::move(value.value_);
214   }
215   /// Retrieves size of dictionary.
size()216   size_t size() const { return value_.dict().size(); }
217 
218  private:
219   // Private since it is in general unsafe.
Dictionary(TaggedValue v)220   explicit Dictionary(TaggedValue v) : Handle(std::move(v)) {}
221   template <class T>
222   friend tensorflow::StatusOr<T> Cast(Handle handle);
223 };
224 
225 /// @brief The Integer class for holding TaggedValues of type INT.
226 class Integer final : public Handle {
227  public:
228   /// Creates a handle that wraps an INT TaggedValue.
Integer(Handle h)229   explicit Integer(Handle h) : Handle(h.value_) {}
230   /// Creates a handle that wraps an INT TaggedValue.
Integer(int64_t i)231   explicit Integer(int64_t i) : Handle(TaggedValue(i)) {}
232   /// Retrieves the underlying integer value.
get()233   int64_t get() const { return value_.i64().get(); }
234 
235  private:
236   // Private since it is in general unsafe.
Integer(TaggedValue v)237   explicit Integer(TaggedValue v) : Handle(std::move(v)) {}
238   template <class T>
239   friend tensorflow::StatusOr<T> Cast(Handle handle);
240 };
241 
242 /// @brief The Float class for holding TaggedValues of type FLOAT.
243 class Float final : public Handle {
244  public:
245   /// Constructs a Float handle that wraps a FLOAT TaggedValue.
Float(Handle h)246   explicit Float(Handle h) : Handle(h.value_) {}
247   /// Constructs a Float handle that wraps a FLOAT TaggedValue.
Float(float i)248   explicit Float(float i) : Handle(TaggedValue(i)) {}
249   /// Retrieves the underlying float value.
get()250   float get() const { return value_.f32().get(); }
251 
252  private:
253   // Private since it is in general unsafe.
Float(TaggedValue v)254   explicit Float(TaggedValue v) : Handle(std::move(v)) {}
255   template <class T>
256   friend tensorflow::StatusOr<T> Cast(Handle handle);
257 };
258 
259 /// @brief The Tensor class for holding TaggedValues of type TENSOR.
260 class Tensor final : public Handle {
261  public:
262   /// Constructs a Tensor handle from a Handle that wraps a TENSOR TaggedValue.
Tensor(Handle h)263   explicit Tensor(Handle h) : Handle(h.value_) {}
264 
265   /// @brief Retrieves the value of the Tensor handle.
266 
267   /// @param data Buffer in which to copy contents of the handle.
268   /// @throws InvalidArgument Raises error if `data` is of invalid size.
269   template <class T>
270   tensorflow::Status GetValue(absl::Span<T> data) const;
271 
272  private:
273   // Private since it is in general unsafe.
Tensor(TaggedValue v)274   explicit Tensor(TaggedValue v) : Handle(std::move(v)) {}
275   template <class T>
276   friend tensorflow::StatusOr<T> Cast(Handle handle);
277 };
278 
279 template <class T>
GetValue(absl::Span<T> data)280 tensorflow::Status Tensor::GetValue(absl::Span<T> data) const {
281   tensorflow::AbstractTensorPtr t;
282   {
283     const auto abstract_t = value_.tensor().get();
284     if (!tensorflow::ImmediateExecutionTensorHandle::classof(abstract_t)) {
285       return tensorflow::errors::InvalidArgument(
286           "Attempting to get value of non eager tensor.");
287     }
288     auto imm_t =
289         static_cast<tensorflow::ImmediateExecutionTensorHandle*>(abstract_t);
290     tensorflow::Status status;
291     t.reset(imm_t->Resolve(&status));
292     if (!status.ok()) {
293       return status;
294     }
295   }
296   if (data.size() != t->NumElements()) {
297     return tensorflow::errors::InvalidArgument(absl::StrCat(
298         "Mismatched number of elements: \n", "Expected: ", data.size(), "\n",
299         "Actual: ", t->NumElements(), "\n"));
300   }
301   memcpy(data.data(), t->Data(), t->ByteSize());
302   return ::tensorflow::OkStatus();
303 }
304 
305 /// @brief The Tuple class for holding TaggedValues of type TUPLE.
306 class Tuple : public Handle {
307  public:
308   /// Constructs a Tuple handle.
309   template <class... T>
Tuple(T...args)310   explicit Tuple(T... args) : Handle(TaggedValue::Tuple()) {
311     add(args...);
312   }
313 
314   /// Retrieves value at index `i`.
315   template <class T>
Get(size_t i)316   tensorflow::StatusOr<T> Get(size_t i) {
317     if (i >= value_.tuple().size())
318       return tensorflow::errors::InvalidArgument("Out of bounds index.");
319     return Cast<T>(Handle(value_.tuple()[i]));
320   }
321 
322   /// Retrieves number of elements.
size()323   size_t size() const { return value_.tuple().size(); }
324 
325  private:
326   // Add an item to a tuple. Should only be done by special construction
327   // like Callables (which are a friend).
add()328   void add() {}
329   template <class T, class... T2>
add(T arg,T2...args)330   void add(T arg, T2... args) {
331     value_.tuple().emplace_back(Convert(arg).value_);
332     add(args...);
333   }
334 
335   // Private since it is in general unsafe.
Tuple(TaggedValue v)336   explicit Tuple(TaggedValue v) : Handle(std::move(v)) {}
337   template <class T>
338   friend tensorflow::StatusOr<T> Cast(Handle handle);
339 };
340 
341 /// @brief The List class for holding TaggedValues of type LIST.
342 class List final : public Handle {
343  public:
344   /// Constructs a List handle.
345   template <class... T>
List(T...args)346   explicit List(T... args) : Handle(TaggedValue::List()) {}
347   /// Retrieves value at index `i`.
348   template <class T>
Get(size_t i)349   tensorflow::StatusOr<T> Get(size_t i) {
350     if (i >= size()) {
351       return tensorflow::errors::InvalidArgument("Out of bounds index.");
352     }
353     return Cast<T>(Handle(value_.list()[i]));
354   }
355 
356   /// Sets value `h` at index `i`.
Set(size_t i,Handle h)357   tensorflow::Status Set(size_t i, Handle h) {
358     if (i >= size()) {
359       return tensorflow::errors::InvalidArgument("Out of bounds index.");
360     }
361     value_.list()[i] = std::move(h.value_);
362     return ::tensorflow::OkStatus();
363   }
364 
365   /// Appends `arg` to list.
366   template <class T>
append(T arg)367   void append(T arg) {
368     value_.list().emplace_back(Convert(arg).value_);
369   }
370   /// Retrieves size of list.
size()371   size_t size() const { return value_.list().size(); }
372 
373  private:
374   // Private since it is in general unsafe.
List(TaggedValue v)375   explicit List(TaggedValue v) : Handle(std::move(v)) {}
376   template <class T>
377   friend tensorflow::StatusOr<T> Cast(Handle handle);
378 };
379 
380 /// @brief The `KeywordArg` class for storing keyword arguments as name value
381 /// pairs.
382 class KeywordArg {
383  public:
KeywordArg(const char * s)384   explicit KeywordArg(const char* s) : key_(String(s)), value_() {}
385 
386   template <class T>
387   KeywordArg& operator=(const T obj) {
388     value_ = Convert(obj);
389     return *this;
390   }
391 
392   friend class Callable;
393 
394  private:
395   String key_;
396   Handle value_;
397 };
398 
399 /// @brief The Callable class for creating callables.
400 class Callable final : public Handle {
401  private:
402   // Collect arguments for call
CollectArgs(Tuple & args,Dictionary & kwargs,int idx)403   void CollectArgs(Tuple& args, Dictionary& kwargs, int idx) {}
404   template <typename T, typename... Types>
CollectArgs(Tuple & args,Dictionary & kwargs,int idx,T v,Types...vars)405   void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, T v,
406                    Types... vars) {
407     const Handle& o = Convert(v);
408     args.value_.tuple().emplace_back(o.value_);
409     CollectArgs(args, kwargs, idx + 1, vars...);
410   }
411   template <typename... Types>
CollectArgs(Tuple & args,Dictionary & kwargs,int idx,KeywordArg v,Types...vars)412   void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, KeywordArg v,
413                    Types... vars) {
414     kwargs.Set(v.key_, v.value_);
415     CollectArgs(args, kwargs, idx + 1, vars...);
416   }
417 
418  public:
419   /// @brief Calls the wrapped TaggedValue function on a variable argument
420   /// list.
421   template <typename TReturn = Handle, typename... Types>
Call(Types...vars)422   tensorflow::StatusOr<TReturn> Call(Types... vars) {
423     Dictionary kwargs = Dictionary();
424     Tuple args;
425     CollectArgs(args, kwargs, 0, vars...);
426     auto maybe_value =
427         value_.func()(std::move(args.value_), std::move(kwargs.value_));
428     if (!maybe_value.ok()) {
429       return maybe_value.status();
430     }
431     return Cast<TReturn>(Handle(maybe_value.ValueOrDie()));
432   }
433 
434  public:
435   // TODO(aselle): need to find a way to write test w/o this being public.
436   // Private since it is in general unsafe.
Callable(TaggedValue v)437   explicit Callable(TaggedValue v) : Handle(std::move(v)) {}
438   template <class T>
439   friend tensorflow::StatusOr<T> Cast(Handle handle);
440 };
441 
442 namespace internal {
443 /// @brief The Capsule class for holding pointers.
444 class Capsule final : public Handle {
445  public:
446   /// Statically cast the TaggedValue capsule to type `T`.
447   template <class T>
cast()448   T cast() {
449     return static_cast<T>(value_.capsule());
450   }
451 
452  private:
453   // Private since it is in general unsafe.
Capsule(TaggedValue v)454   explicit Capsule(TaggedValue v) : Handle(std::move(v)) {}
455   template <class T>
456   friend tensorflow::StatusOr<T> tf::libtf::Cast(Handle handle);
457 };
458 }  // namespace internal
459 
460 /// @defgroup Util Functions for type conversion
461 ///
462 /// @brief Functions for retrieving and converting Handle types.
463 /// @{
464 
465 /// Retrieves tagged type of `T` handle.
466 template <class T>
TypeToTaggedType()467 inline TaggedValue::Type TypeToTaggedType() {}
468 /// Retrieves tagged type of base class handle.
469 template <>
470 inline TaggedValue::Type TypeToTaggedType<Handle>() {
471   return TaggedValue::Type::NONE;
472 }
473 /// Retrieves tagged type of None handle.
474 template <>
475 inline TaggedValue::Type TypeToTaggedType<None>() {
476   return TaggedValue::Type::NONE;
477 }
478 /// Retrieves tagged type of String handle.
479 template <>
480 inline TaggedValue::Type TypeToTaggedType<String>() {
481   return TaggedValue::Type::STRING;
482 }
483 /// Retrieves tagged type of Callable handle.
484 template <>
485 inline TaggedValue::Type TypeToTaggedType<Callable>() {
486   return TaggedValue::Type::FUNC;
487 }
488 /// Retrieves tagged type of Integer handle.
489 template <>
490 inline TaggedValue::Type TypeToTaggedType<Integer>() {
491   return TaggedValue::Type::INT64;
492 }
493 /// Retrieves tagged type of Float handle.
494 template <>
495 inline TaggedValue::Type TypeToTaggedType<Float>() {
496   return TaggedValue::Type::FLOAT32;
497 }
498 /// Retrieves tagged type of Object handle.
499 template <>
500 inline TaggedValue::Type TypeToTaggedType<Object>() {
501   return TaggedValue::Type::DICT;
502 }
503 /// Retrieves tagged type of Dictionary handle.
504 template <>
505 inline TaggedValue::Type TypeToTaggedType<Dictionary>() {
506   return TaggedValue::Type::DICT;
507 }
508 /// Retrieves tagged type of List handle.
509 template <>
510 inline TaggedValue::Type TypeToTaggedType<List>() {
511   return TaggedValue::Type::LIST;
512 }
513 /// Retrieves tagged type of Tensor handle.
514 template <>
515 inline TaggedValue::Type TypeToTaggedType<Tensor>() {
516   return TaggedValue::Type::TENSOR;
517 }
518 /// Retrieves tagged type of Capsule handle.
519 template <>
520 inline TaggedValue::Type TypeToTaggedType<internal::Capsule>() {
521   return TaggedValue::Type::CAPSULE;
522 }
523 // TODO(unknown): fully populate
524 
525 /// @brief Casts a handle to type `T`
526 ///
527 /// @param handle The handle to cast.
528 /// @tparam T The target handle type.
529 /// @exception InvalidArgument Raises error if the underlying TaggedValue type
530 /// of `handle` is not equivalent to `T`.
531 template <class T>
Cast(Handle handle)532 tensorflow::StatusOr<T> Cast(Handle handle) {
533   if (handle.value_.type() == TypeToTaggedType<T>() ||
534       std::is_same<T, Handle>::value)
535     return T((std::move(handle.value_)));
536   return tensorflow::errors::InvalidArgument("Incompatible cast.");
537 }
538 
539 // Converters for C++ primitives like float and int to handles. Allows callable
540 // calls and list appends to be more idiomatic.
541 
542 /// Converts a C++ const char* to a String handle.
543 template <>
Convert(const char * value)544 inline Handle Convert(const char* value) {
545   return String(value);
546 }
547 /// Converts a C++ int32_t to an Integer handle.
548 template <>
Convert(int32_t value)549 inline Handle Convert(int32_t value) {
550   return Integer(value);
551 }
552 /// Converts a C++ int64_t to an Integer handle.
553 template <>
Convert(int64_t value)554 inline Handle Convert(int64_t value) {
555   return Integer(value);
556 }
557 /// Converts a C++ float to an Integer handle.
558 template <>
Convert(float value)559 inline Handle Convert(float value) {
560   return Float(value);
561 }
562 /// Converts a value with primitive type T to a Handle.
563 template <class T>
Convert(T value)564 inline Handle Convert(T value) {
565   return Handle(std::move(value));
566 }
567 
568 /// @}
569 
570 // in the future it will be possible to make additional hard typed APIs
571 // by generating code by introspecting objects.
572 
573 // Here's a code gen'd example
574 // The dynamic structure can be turned into it.
575 /*
576 class Tf : Object {
577   Tensor ones(Tensor shape, String dtype);
578   // ...
579 }
580 */
581 
582 // Adapter to allow users to define Callables. Use TFLIB_CALLABLE_ADAPTOR
583 // instead.
584 template <typename TF, typename TReturn, typename... TFuncArgs>
585 class CallableWrapper;
586 
587 // Template extracts arguments from a lambda function. This base
588 // class definition inherits from a another specialization in order. We use
589 // this top level template to extract the function pointer associated with
590 // the created lambda functor class.
591 template <typename TLambda>
592 class CallableWrapperUnpackArgs
593     : public CallableWrapperUnpackArgs<decltype(&TLambda::operator())> {
594  public:
CallableWrapperUnpackArgs(TLambda fn,const char * name)595   CallableWrapperUnpackArgs(TLambda fn, const char* name)
596       : CallableWrapperUnpackArgs<decltype(&TLambda::operator())>(fn, name) {}
597 };
598 
599 // This specialization unpacks the arguments from a normal function pointer.
600 template <typename TReturn, typename... TFuncArgs>
601 class CallableWrapperUnpackArgs<TReturn (*)(TFuncArgs...)>
602     : public CallableWrapper<TReturn (*)(TFuncArgs...), TReturn, TFuncArgs...> {
603   using Fn = TReturn (*)(TFuncArgs...);
604 
605  public:
CallableWrapperUnpackArgs(Fn fn,const char * name)606   CallableWrapperUnpackArgs(Fn fn, const char* name)
607       : CallableWrapper<Fn, TReturn, TFuncArgs...>(fn, name) {}
608 };
609 
610 // This is the second stage of extracting the arguments from lambda function.
611 // NOTE: CallableWrapper's first template argument is the type of the
612 // function or functor (not the member pointer).
613 template <typename TClass, typename TReturn, typename... TFuncArgs>
614 class CallableWrapperUnpackArgs<TReturn (TClass::*)(TFuncArgs...) const>
615     : public CallableWrapper<TClass, TReturn, TFuncArgs...> {
616   using Fn = TClass;
617 
618  public:
CallableWrapperUnpackArgs(Fn fn,const char * name)619   CallableWrapperUnpackArgs(Fn fn, const char* name)
620       : CallableWrapper<Fn, TReturn, TFuncArgs...>(fn, name) {}
621 };
622 
623 template <class Fn, typename TReturn, class... ArgsOut>
624 class UneraseCallHelper;
625 
626 // UneraseCallHelper::Call allows transforming all the incoming arguments
627 // from a TaggedValue tuple to a variadic list of args.  The class template
628 // starts as a list of argument types and ends empty. The static member
629 // template starts empty and ends with the unerased types of the signature.
630 
631 // Base case (all arguments are processed, so call the function TFunc.
632 template <class Fn, typename TReturn>
633 class UneraseCallHelper<Fn, TReturn> {
634  public:
635   template <typename... ArgsOut>
Call(const char * name,Fn functor_,int argument_index,const TaggedValue & args_in,ArgsOut...args)636   static tensorflow::StatusOr<TaggedValue> Call(const char* name, Fn functor_,
637                                                 int argument_index,
638                                                 const TaggedValue& args_in,
639                                                 ArgsOut... args) {
640     // Call concrete type function
641     TReturn ret = functor_(args...);
642     return ret.value_;
643   }
644 };
645 
646 // Unpack a single argument case. Each argument is then cast.
647 template <class Fn, typename TReturn, class TSignatureArg,
648           class... TSignatureRest>
649 class UneraseCallHelper<Fn, TReturn, TSignatureArg, TSignatureRest...> {
650  public:
651   template <typename... TArgsOut>
Call(const char * name,Fn fn,int argument_index,TaggedValue & args_in,TArgsOut...args)652   static tensorflow::StatusOr<TaggedValue> Call(const char* name, Fn fn,
653                                                 int argument_index,
654                                                 TaggedValue& args_in,
655                                                 TArgsOut... args) {
656     Handle h(std::move(args_in.tuple()[argument_index]));
657     tensorflow::StatusOr<TSignatureArg> x = Cast<TSignatureArg>(std::move(h));
658     if (!x.ok())
659       return tensorflow::errors::InvalidArgument(
660           std::string("Function ") + name + " Arg " +
661           std::to_string(argument_index) +
662           " cannot be cast to desired signature type ");
663     return UneraseCallHelper<Fn, TReturn, TSignatureRest...>::template Call(
664         name, fn, argument_index + 1, args_in, args..., *x);
665   }
666 };
667 
668 // Template specialization that allows extracting arguments from a C function
669 // pointer.
670 template <class Fn, typename TReturn, typename... TFuncArgs>
671 class CallableWrapper {
672  private:
673   Fn functor_;
674   const char* name_;
675 
676  public:
CallableWrapper(Fn fn,const char * name)677   explicit CallableWrapper(Fn fn, const char* name)
678       : functor_(fn), name_(name) {}
679 
680   // Entry point of the Adaptor functor. Note args, and kwargs are attempted
681   // to be moved.
operator()682   tensorflow::StatusOr<TaggedValue> operator()(TaggedValue args,
683                                                TaggedValue kwargs) {
684     constexpr size_t argument_count = sizeof...(TFuncArgs);
685     if (argument_count != args.tuple().size())
686       return tensorflow::errors::InvalidArgument(
687           std::string("Function ") + name_ + " expected " +
688           std::to_string(argument_count) + " args.");
689     return UneraseCallHelper<Fn, TReturn, TFuncArgs...>::Call(name_, functor_,
690                                                               0, args);
691   }
692 };
693 
694 // Wrap a function that uses object handles as arguments and return types
695 // with one that takes TaggedValues. For example:
696 // Tuple Pack(Integer, Float, String);
697 // TaggedValue callable = TFLIB_CALLABLE_ADAPTOR(Pack);
698 #define TFLIB_CALLABLE_ADAPTOR(x) ::tf::libtf::CreateCallableAdaptor(x, #x)
699 
700 template <class TF>
CreateCallableAdaptor(TF x,const char * name)701 TaggedValue CreateCallableAdaptor(TF x, const char* name) {
702   return TaggedValue((CallableWrapperUnpackArgs<TF>(x, name)));
703 }
704 
705 }  // namespace libtf
706 }  // namespace tf
707 
708 #endif  // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_
709