1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 /// @file object.h
16 /// @brief Object hierarchy for the TensorFlow C++ API. All "objects" are
17 /// derived from the `Handle` class. Instances of `Handle` are referred to as
18 /// "handles". All handles have a tagged value.
19 ///
20 /// Example Usage:
21 /// Object runtime = GetRuntime("tfrt");
/// Object module = runtime.Get("Import")("cool_mobilenet");
23 /// runtime.Get("Tensor")(Tuple(5,5,5), 3.3);
24 /// Object test = CreateModule("test");
25 /// test.Set("cool_function", callable);
26 #ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_
27 #define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_
28
29 #include <string>
30 #include <utility>
31
32 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
33 #include "tensorflow/cc/experimental/libtf/value.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/status.h"
36 #include "tensorflow/core/platform/statusor.h"
37
38 namespace tf {
39 namespace libtf {
40
41 using TaggedValue = impl::TaggedValue;
42 class Handle;
43
44 // Necessary forward declare.
45 template <class T>
46 Handle Convert(T value);
47
48 /// @brief Base Handle class that wraps TaggedValue data. All data creation and
49 /// manipulation should done using Handle instances. Users should not be working
50 /// with TaggedValues directly.
51
52 /// The `Handle` class contains a TaggedValue in the `value_` member, which
53 /// contains the underlying data. An object belonging to `Foo`, a derived class
54 /// of `Handle`, can be referred to as a `Foo` handle.
55 ///
56 /// It is important that all derived classes do not add any new data fields.
57 /// This ensures that it is always safe to slice down (i.e. assign an object of
58 /// a derived class to the base class) a handle to the base Handle class.
59 class Handle {
60 public:
61 /// Default constructor, which initializes a TaggedValue with type NONE.
Handle()62 Handle() : value_(TaggedValue::None()) {}
63
64 public:
65 /// Constructs a handle from a TaggedValue.
Handle(TaggedValue value)66 explicit Handle(TaggedValue value) : value_(std::move(value)) {}
67 // explicit Handle(TaggedValue value, Handle* class_input)
68 // : value_(std::move(value)), class_(class_input) {}
69 // const Handle& type() { return *class_; }
70
71 protected:
72 /// The wrapped TaggedValue.
73 TaggedValue value_;
74 // effectively a "weak reference" to intern'd class value.
75 // types are compared by comparing pointer values here.
76 // Handle* class_; // effectively a "weak reference" to intern'd class value.
77
78 /// The Integer handle.
79 friend class Integer;
80 /// The Float handle.
81 friend class Float;
82 /// The String handle.
83 friend class String;
84 /// The Object handle.
85 friend class Object;
86 /// The List handle.
87 friend class List;
88 /// The Dictionary handle.
89 friend class Dictionary;
90 /// The Tuple handle.
91 friend class Tuple;
92 /// The Callable handle.
93 friend class Callable;
94 /// The Tensor handle.
95 friend class Tensor;
96 /// Converts a Handle instance to an instance of a derived class `T`.
97 template <class T>
98 friend tensorflow::StatusOr<T> Cast(Handle handle);
99 /// Infrastructure for converting a TaggedValue tuple function signature to an
100 /// unpacked variable list.
101 template <typename Fn, class TRET, class... ArgsOut>
102 friend class UneraseCallHelper;
103 };
104
105 // Forward declare.
106 template <class T>
107 tensorflow::StatusOr<T> Cast(Handle handle);
108
109 /// @brief The None class for holding TaggedValues of type NONE.
110 class None final : public Handle {
111 public:
112 /// Creates a handle that wraps a NONE TaggedValue.
None()113 None() : Handle(TaggedValue::None()) {}
114
115 private:
None(TaggedValue v)116 explicit None(TaggedValue v) : Handle(std::move(v)) {}
117 template <class T>
118 friend tensorflow::StatusOr<T> Cast(Handle handle);
119 };
120
121 /// @brief The String class for holding TaggedValues of type STRING.
122 class String final : public Handle {
123 public:
124 /// Creates a handle that wraps a STRING TaggedValue.
String(const char * s)125 explicit String(const char* s) : Handle(TaggedValue(s)) {}
126 /// Returns the underlying TaggedValue string.
get()127 const char* get() const { return value_.s(); }
128
129 private:
130 // Private since it is in general unsafe.
String(TaggedValue v)131 explicit String(TaggedValue v) : Handle(std::move(v)) {}
132 template <class T>
133 friend tensorflow::StatusOr<T> Cast(Handle handle);
134 };
135
136 /// @brief The `Object` class modeled after Python "objects".
137 ///
138 /// An `Object` uses a TaggedValue dictionary to store its attributes. The
139 /// "__parent__" attribute is reserved.
140 class Object : public Handle {
141 public:
142 /// Constructs a handle that acts as an object.
Object()143 Object() : Handle(TaggedValue::Dict()) {}
144 /// Retrieves the key of the object's parent.
145 static const String& ParentKey();
146
147 /// @brief Gets an object member attribute`key`.
148 ///
149 /// If the `key` is not found in the object, the object's "__parent__"
150 /// attribute is then searched.
151 ///
152 /// @tparam T The desired return type.
153 /// @param key The key to look up.
154 /// @return `StatusOr` wrapping the key's value.
155 template <class T = Handle>
Get(const String & key)156 tensorflow::StatusOr<T> Get(const String& key) {
157 auto& dict = value_.dict();
158 auto it = dict.find(key.value_);
159 if (it != dict.end()) {
160 return Cast<T>(Handle(it->second));
161 } else {
162 // Lookup in object stored by reference in attribute "__parent__".
163 auto it_class = dict.find(ParentKey().value_);
164 if (it_class != dict.end()) {
165 auto& class_dict_maybe = it_class->second;
166 if (class_dict_maybe.type() == TaggedValue::DICT) {
167 auto& dict = class_dict_maybe.dict();
168 auto it = dict.find(key.value_);
169 if (it != value_.dict().end()) {
170 return Cast<T>(Handle(it->second));
171 }
172 }
173 }
174 }
175 return tensorflow::errors::NotFound("Key not in dictionary.");
176 }
177
178 /// Sets `key` attribute with the underlying value of `h`.
Set(const String & key,Handle h)179 void Set(const String& key, Handle h) {
180 value_.dict()[key.value_] = std::move(h.value_);
181 }
182
183 /// Removes `key` from the object's attributes.
Unset(const String & key)184 void Unset(const String& key) { value_.dict().erase(key.value_); }
185 // TODO(b/): Adding dir() is in the future.
186 private:
187 // Private since it is in general unsafe.
Object(TaggedValue v)188 explicit Object(TaggedValue v) : Handle(std::move(v)) {}
189 template <class T>
190 friend tensorflow::StatusOr<T> Cast(Handle handle);
191 };
192
193 /// @brief The Dictionary class for holding TaggedValues of type DICT.
194 class Dictionary final : public Handle {
195 public:
196 /// Constructs a handle that wraps a DICT TaggedValue.
Dictionary()197 Dictionary() : Handle(TaggedValue::Dict()) {}
198 // TODO(aselle): make this private to preserve invariant.
199
200 /// Retrieves `key` with type `T`.
201 template <class T>
Get(const Handle & key)202 tensorflow::StatusOr<T> Get(const Handle& key) {
203 auto it = value_.dict().find(key.value_);
204 if (it != value_.dict().end()) return Cast<T>(Handle(it->second));
205 return tensorflow::errors::NotFound("Key not in dictionary.");
206 }
207 /// Sets `key` with value `value`.
Set(const String & key,Handle value)208 void Set(const String& key, Handle value) {
209 value_.dict()[key.value_] = std::move(value.value_);
210 }
211 /// Sets `key` with value `value`.
Set(const Handle & key,Handle value)212 void Set(const Handle& key, Handle value) {
213 value_.dict()[key.value_] = std::move(value.value_);
214 }
215 /// Retrieves size of dictionary.
size()216 size_t size() const { return value_.dict().size(); }
217
218 private:
219 // Private since it is in general unsafe.
Dictionary(TaggedValue v)220 explicit Dictionary(TaggedValue v) : Handle(std::move(v)) {}
221 template <class T>
222 friend tensorflow::StatusOr<T> Cast(Handle handle);
223 };
224
225 /// @brief The Integer class for holding TaggedValues of type INT.
226 class Integer final : public Handle {
227 public:
228 /// Creates a handle that wraps an INT TaggedValue.
Integer(Handle h)229 explicit Integer(Handle h) : Handle(h.value_) {}
230 /// Creates a handle that wraps an INT TaggedValue.
Integer(int64_t i)231 explicit Integer(int64_t i) : Handle(TaggedValue(i)) {}
232 /// Retrieves the underlying integer value.
get()233 int64_t get() const { return value_.i64().get(); }
234
235 private:
236 // Private since it is in general unsafe.
Integer(TaggedValue v)237 explicit Integer(TaggedValue v) : Handle(std::move(v)) {}
238 template <class T>
239 friend tensorflow::StatusOr<T> Cast(Handle handle);
240 };
241
242 /// @brief The Float class for holding TaggedValues of type FLOAT.
243 class Float final : public Handle {
244 public:
245 /// Constructs a Float handle that wraps a FLOAT TaggedValue.
Float(Handle h)246 explicit Float(Handle h) : Handle(h.value_) {}
247 /// Constructs a Float handle that wraps a FLOAT TaggedValue.
Float(float i)248 explicit Float(float i) : Handle(TaggedValue(i)) {}
249 /// Retrieves the underlying float value.
get()250 float get() const { return value_.f32().get(); }
251
252 private:
253 // Private since it is in general unsafe.
Float(TaggedValue v)254 explicit Float(TaggedValue v) : Handle(std::move(v)) {}
255 template <class T>
256 friend tensorflow::StatusOr<T> Cast(Handle handle);
257 };
258
259 /// @brief The Tensor class for holding TaggedValues of type TENSOR.
260 class Tensor final : public Handle {
261 public:
262 /// Constructs a Tensor handle from a Handle that wraps a TENSOR TaggedValue.
Tensor(Handle h)263 explicit Tensor(Handle h) : Handle(h.value_) {}
264
265 /// @brief Retrieves the value of the Tensor handle.
266
267 /// @param data Buffer in which to copy contents of the handle.
268 /// @throws InvalidArgument Raises error if `data` is of invalid size.
269 template <class T>
270 tensorflow::Status GetValue(absl::Span<T> data) const;
271
272 private:
273 // Private since it is in general unsafe.
Tensor(TaggedValue v)274 explicit Tensor(TaggedValue v) : Handle(std::move(v)) {}
275 template <class T>
276 friend tensorflow::StatusOr<T> Cast(Handle handle);
277 };
278
279 template <class T>
GetValue(absl::Span<T> data)280 tensorflow::Status Tensor::GetValue(absl::Span<T> data) const {
281 tensorflow::AbstractTensorPtr t;
282 {
283 const auto abstract_t = value_.tensor().get();
284 if (!tensorflow::ImmediateExecutionTensorHandle::classof(abstract_t)) {
285 return tensorflow::errors::InvalidArgument(
286 "Attempting to get value of non eager tensor.");
287 }
288 auto imm_t =
289 static_cast<tensorflow::ImmediateExecutionTensorHandle*>(abstract_t);
290 tensorflow::Status status;
291 t.reset(imm_t->Resolve(&status));
292 if (!status.ok()) {
293 return status;
294 }
295 }
296 if (data.size() != t->NumElements()) {
297 return tensorflow::errors::InvalidArgument(absl::StrCat(
298 "Mismatched number of elements: \n", "Expected: ", data.size(), "\n",
299 "Actual: ", t->NumElements(), "\n"));
300 }
301 memcpy(data.data(), t->Data(), t->ByteSize());
302 return ::tensorflow::OkStatus();
303 }
304
305 /// @brief The Tuple class for holding TaggedValues of type TUPLE.
306 class Tuple : public Handle {
307 public:
308 /// Constructs a Tuple handle.
309 template <class... T>
Tuple(T...args)310 explicit Tuple(T... args) : Handle(TaggedValue::Tuple()) {
311 add(args...);
312 }
313
314 /// Retrieves value at index `i`.
315 template <class T>
Get(size_t i)316 tensorflow::StatusOr<T> Get(size_t i) {
317 if (i >= value_.tuple().size())
318 return tensorflow::errors::InvalidArgument("Out of bounds index.");
319 return Cast<T>(Handle(value_.tuple()[i]));
320 }
321
322 /// Retrieves number of elements.
size()323 size_t size() const { return value_.tuple().size(); }
324
325 private:
326 // Add an item to a tuple. Should only be done by special construction
327 // like Callables (which are a friend).
add()328 void add() {}
329 template <class T, class... T2>
add(T arg,T2...args)330 void add(T arg, T2... args) {
331 value_.tuple().emplace_back(Convert(arg).value_);
332 add(args...);
333 }
334
335 // Private since it is in general unsafe.
Tuple(TaggedValue v)336 explicit Tuple(TaggedValue v) : Handle(std::move(v)) {}
337 template <class T>
338 friend tensorflow::StatusOr<T> Cast(Handle handle);
339 };
340
341 /// @brief The List class for holding TaggedValues of type LIST.
342 class List final : public Handle {
343 public:
344 /// Constructs a List handle.
345 template <class... T>
List(T...args)346 explicit List(T... args) : Handle(TaggedValue::List()) {}
347 /// Retrieves value at index `i`.
348 template <class T>
Get(size_t i)349 tensorflow::StatusOr<T> Get(size_t i) {
350 if (i >= size()) {
351 return tensorflow::errors::InvalidArgument("Out of bounds index.");
352 }
353 return Cast<T>(Handle(value_.list()[i]));
354 }
355
356 /// Sets value `h` at index `i`.
Set(size_t i,Handle h)357 tensorflow::Status Set(size_t i, Handle h) {
358 if (i >= size()) {
359 return tensorflow::errors::InvalidArgument("Out of bounds index.");
360 }
361 value_.list()[i] = std::move(h.value_);
362 return ::tensorflow::OkStatus();
363 }
364
365 /// Appends `arg` to list.
366 template <class T>
append(T arg)367 void append(T arg) {
368 value_.list().emplace_back(Convert(arg).value_);
369 }
370 /// Retrieves size of list.
size()371 size_t size() const { return value_.list().size(); }
372
373 private:
374 // Private since it is in general unsafe.
List(TaggedValue v)375 explicit List(TaggedValue v) : Handle(std::move(v)) {}
376 template <class T>
377 friend tensorflow::StatusOr<T> Cast(Handle handle);
378 };
379
380 /// @brief The `KeywordArg` class for storing keyword arguments as name value
381 /// pairs.
382 class KeywordArg {
383 public:
KeywordArg(const char * s)384 explicit KeywordArg(const char* s) : key_(String(s)), value_() {}
385
386 template <class T>
387 KeywordArg& operator=(const T obj) {
388 value_ = Convert(obj);
389 return *this;
390 }
391
392 friend class Callable;
393
394 private:
395 String key_;
396 Handle value_;
397 };
398
399 /// @brief The Callable class for creating callables.
400 class Callable final : public Handle {
401 private:
402 // Collect arguments for call
CollectArgs(Tuple & args,Dictionary & kwargs,int idx)403 void CollectArgs(Tuple& args, Dictionary& kwargs, int idx) {}
404 template <typename T, typename... Types>
CollectArgs(Tuple & args,Dictionary & kwargs,int idx,T v,Types...vars)405 void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, T v,
406 Types... vars) {
407 const Handle& o = Convert(v);
408 args.value_.tuple().emplace_back(o.value_);
409 CollectArgs(args, kwargs, idx + 1, vars...);
410 }
411 template <typename... Types>
CollectArgs(Tuple & args,Dictionary & kwargs,int idx,KeywordArg v,Types...vars)412 void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, KeywordArg v,
413 Types... vars) {
414 kwargs.Set(v.key_, v.value_);
415 CollectArgs(args, kwargs, idx + 1, vars...);
416 }
417
418 public:
419 /// @brief Calls the wrapped TaggedValue function on a variable argument
420 /// list.
421 template <typename TReturn = Handle, typename... Types>
Call(Types...vars)422 tensorflow::StatusOr<TReturn> Call(Types... vars) {
423 Dictionary kwargs = Dictionary();
424 Tuple args;
425 CollectArgs(args, kwargs, 0, vars...);
426 auto maybe_value =
427 value_.func()(std::move(args.value_), std::move(kwargs.value_));
428 if (!maybe_value.ok()) {
429 return maybe_value.status();
430 }
431 return Cast<TReturn>(Handle(maybe_value.ValueOrDie()));
432 }
433
434 public:
435 // TODO(aselle): need to find a way to write test w/o this being public.
436 // Private since it is in general unsafe.
Callable(TaggedValue v)437 explicit Callable(TaggedValue v) : Handle(std::move(v)) {}
438 template <class T>
439 friend tensorflow::StatusOr<T> Cast(Handle handle);
440 };
441
442 namespace internal {
443 /// @brief The Capsule class for holding pointers.
444 class Capsule final : public Handle {
445 public:
446 /// Statically cast the TaggedValue capsule to type `T`.
447 template <class T>
cast()448 T cast() {
449 return static_cast<T>(value_.capsule());
450 }
451
452 private:
453 // Private since it is in general unsafe.
Capsule(TaggedValue v)454 explicit Capsule(TaggedValue v) : Handle(std::move(v)) {}
455 template <class T>
456 friend tensorflow::StatusOr<T> tf::libtf::Cast(Handle handle);
457 };
458 } // namespace internal
459
460 /// @defgroup Util Functions for type conversion
461 ///
462 /// @brief Functions for retrieving and converting Handle types.
463 /// @{
464
465 /// Retrieves tagged type of `T` handle.
466 template <class T>
TypeToTaggedType()467 inline TaggedValue::Type TypeToTaggedType() {}
468 /// Retrieves tagged type of base class handle.
469 template <>
470 inline TaggedValue::Type TypeToTaggedType<Handle>() {
471 return TaggedValue::Type::NONE;
472 }
473 /// Retrieves tagged type of None handle.
474 template <>
475 inline TaggedValue::Type TypeToTaggedType<None>() {
476 return TaggedValue::Type::NONE;
477 }
478 /// Retrieves tagged type of String handle.
479 template <>
480 inline TaggedValue::Type TypeToTaggedType<String>() {
481 return TaggedValue::Type::STRING;
482 }
483 /// Retrieves tagged type of Callable handle.
484 template <>
485 inline TaggedValue::Type TypeToTaggedType<Callable>() {
486 return TaggedValue::Type::FUNC;
487 }
488 /// Retrieves tagged type of Integer handle.
489 template <>
490 inline TaggedValue::Type TypeToTaggedType<Integer>() {
491 return TaggedValue::Type::INT64;
492 }
493 /// Retrieves tagged type of Float handle.
494 template <>
495 inline TaggedValue::Type TypeToTaggedType<Float>() {
496 return TaggedValue::Type::FLOAT32;
497 }
498 /// Retrieves tagged type of Object handle.
499 template <>
500 inline TaggedValue::Type TypeToTaggedType<Object>() {
501 return TaggedValue::Type::DICT;
502 }
503 /// Retrieves tagged type of Dictionary handle.
504 template <>
505 inline TaggedValue::Type TypeToTaggedType<Dictionary>() {
506 return TaggedValue::Type::DICT;
507 }
508 /// Retrieves tagged type of List handle.
509 template <>
510 inline TaggedValue::Type TypeToTaggedType<List>() {
511 return TaggedValue::Type::LIST;
512 }
513 /// Retrieves tagged type of Tensor handle.
514 template <>
515 inline TaggedValue::Type TypeToTaggedType<Tensor>() {
516 return TaggedValue::Type::TENSOR;
517 }
518 /// Retrieves tagged type of Capsule handle.
519 template <>
520 inline TaggedValue::Type TypeToTaggedType<internal::Capsule>() {
521 return TaggedValue::Type::CAPSULE;
522 }
523 // TODO(unknown): fully populate
524
525 /// @brief Casts a handle to type `T`
526 ///
527 /// @param handle The handle to cast.
528 /// @tparam T The target handle type.
529 /// @exception InvalidArgument Raises error if the underlying TaggedValue type
530 /// of `handle` is not equivalent to `T`.
531 template <class T>
Cast(Handle handle)532 tensorflow::StatusOr<T> Cast(Handle handle) {
533 if (handle.value_.type() == TypeToTaggedType<T>() ||
534 std::is_same<T, Handle>::value)
535 return T((std::move(handle.value_)));
536 return tensorflow::errors::InvalidArgument("Incompatible cast.");
537 }
538
539 // Converters for C++ primitives like float and int to handles. Allows callable
540 // calls and list appends to be more idiomatic.
541
542 /// Converts a C++ const char* to a String handle.
543 template <>
Convert(const char * value)544 inline Handle Convert(const char* value) {
545 return String(value);
546 }
547 /// Converts a C++ int32_t to an Integer handle.
548 template <>
Convert(int32_t value)549 inline Handle Convert(int32_t value) {
550 return Integer(value);
551 }
552 /// Converts a C++ int64_t to an Integer handle.
553 template <>
Convert(int64_t value)554 inline Handle Convert(int64_t value) {
555 return Integer(value);
556 }
557 /// Converts a C++ float to an Integer handle.
558 template <>
Convert(float value)559 inline Handle Convert(float value) {
560 return Float(value);
561 }
562 /// Converts a value with primitive type T to a Handle.
563 template <class T>
Convert(T value)564 inline Handle Convert(T value) {
565 return Handle(std::move(value));
566 }
567
568 /// @}
569
570 // in the future it will be possible to make additional hard typed APIs
571 // by generating code by introspecting objects.
572
573 // Here's a code gen'd example
574 // The dynamic structure can be turned into it.
575 /*
576 class Tf : Object {
577 Tensor ones(Tensor shape, String dtype);
578 // ...
579 }
580 */
581
582 // Adapter to allow users to define Callables. Use TFLIB_CALLABLE_ADAPTOR
583 // instead.
584 template <typename TF, typename TReturn, typename... TFuncArgs>
585 class CallableWrapper;
586
// Extracts the argument types of a lambda (or any functor). This primary
// template inherits from the specialization selected by the type of the
// functor's call operator, `decltype(&TLambda::operator())`, which is where
// the return and argument types get unpacked.
template <typename TLambda>
class CallableWrapperUnpackArgs
    : public CallableWrapperUnpackArgs<decltype(&TLambda::operator())> {
  using Base = CallableWrapperUnpackArgs<decltype(&TLambda::operator())>;

 public:
  // Forwards the functor and its display name to the unpacking base.
  CallableWrapperUnpackArgs(TLambda fn, const char* name) : Base(fn, name) {}
};
598
599 // This specialization unpacks the arguments from a normal function pointer.
600 template <typename TReturn, typename... TFuncArgs>
601 class CallableWrapperUnpackArgs<TReturn (*)(TFuncArgs...)>
602 : public CallableWrapper<TReturn (*)(TFuncArgs...), TReturn, TFuncArgs...> {
603 using Fn = TReturn (*)(TFuncArgs...);
604
605 public:
CallableWrapperUnpackArgs(Fn fn,const char * name)606 CallableWrapperUnpackArgs(Fn fn, const char* name)
607 : CallableWrapper<Fn, TReturn, TFuncArgs...>(fn, name) {}
608 };
609
610 // This is the second stage of extracting the arguments from lambda function.
611 // NOTE: CallableWrapper's first template argument is the type of the
612 // function or functor (not the member pointer).
613 template <typename TClass, typename TReturn, typename... TFuncArgs>
614 class CallableWrapperUnpackArgs<TReturn (TClass::*)(TFuncArgs...) const>
615 : public CallableWrapper<TClass, TReturn, TFuncArgs...> {
616 using Fn = TClass;
617
618 public:
CallableWrapperUnpackArgs(Fn fn,const char * name)619 CallableWrapperUnpackArgs(Fn fn, const char* name)
620 : CallableWrapper<Fn, TReturn, TFuncArgs...>(fn, name) {}
621 };
622
623 template <class Fn, typename TReturn, class... ArgsOut>
624 class UneraseCallHelper;
625
626 // UneraseCallHelper::Call allows transforming all the incoming arguments
627 // from a TaggedValue tuple to a variadic list of args. The class template
628 // starts as a list of argument types and ends empty. The static member
629 // template starts empty and ends with the unerased types of the signature.
630
631 // Base case (all arguments are processed, so call the function TFunc.
632 template <class Fn, typename TReturn>
633 class UneraseCallHelper<Fn, TReturn> {
634 public:
635 template <typename... ArgsOut>
Call(const char * name,Fn functor_,int argument_index,const TaggedValue & args_in,ArgsOut...args)636 static tensorflow::StatusOr<TaggedValue> Call(const char* name, Fn functor_,
637 int argument_index,
638 const TaggedValue& args_in,
639 ArgsOut... args) {
640 // Call concrete type function
641 TReturn ret = functor_(args...);
642 return ret.value_;
643 }
644 };
645
646 // Unpack a single argument case. Each argument is then cast.
647 template <class Fn, typename TReturn, class TSignatureArg,
648 class... TSignatureRest>
649 class UneraseCallHelper<Fn, TReturn, TSignatureArg, TSignatureRest...> {
650 public:
651 template <typename... TArgsOut>
Call(const char * name,Fn fn,int argument_index,TaggedValue & args_in,TArgsOut...args)652 static tensorflow::StatusOr<TaggedValue> Call(const char* name, Fn fn,
653 int argument_index,
654 TaggedValue& args_in,
655 TArgsOut... args) {
656 Handle h(std::move(args_in.tuple()[argument_index]));
657 tensorflow::StatusOr<TSignatureArg> x = Cast<TSignatureArg>(std::move(h));
658 if (!x.ok())
659 return tensorflow::errors::InvalidArgument(
660 std::string("Function ") + name + " Arg " +
661 std::to_string(argument_index) +
662 " cannot be cast to desired signature type ");
663 return UneraseCallHelper<Fn, TReturn, TSignatureRest...>::template Call(
664 name, fn, argument_index + 1, args_in, args..., *x);
665 }
666 };
667
668 // Template specialization that allows extracting arguments from a C function
669 // pointer.
670 template <class Fn, typename TReturn, typename... TFuncArgs>
671 class CallableWrapper {
672 private:
673 Fn functor_;
674 const char* name_;
675
676 public:
CallableWrapper(Fn fn,const char * name)677 explicit CallableWrapper(Fn fn, const char* name)
678 : functor_(fn), name_(name) {}
679
680 // Entry point of the Adaptor functor. Note args, and kwargs are attempted
681 // to be moved.
operator()682 tensorflow::StatusOr<TaggedValue> operator()(TaggedValue args,
683 TaggedValue kwargs) {
684 constexpr size_t argument_count = sizeof...(TFuncArgs);
685 if (argument_count != args.tuple().size())
686 return tensorflow::errors::InvalidArgument(
687 std::string("Function ") + name_ + " expected " +
688 std::to_string(argument_count) + " args.");
689 return UneraseCallHelper<Fn, TReturn, TFuncArgs...>::Call(name_, functor_,
690 0, args);
691 }
692 };
693
694 // Wrap a function that uses object handles as arguments and return types
695 // with one that takes TaggedValues. For example:
696 // Tuple Pack(Integer, Float, String);
697 // TaggedValue callable = TFLIB_CALLABLE_ADAPTOR(Pack);
698 #define TFLIB_CALLABLE_ADAPTOR(x) ::tf::libtf::CreateCallableAdaptor(x, #x)
699
700 template <class TF>
CreateCallableAdaptor(TF x,const char * name)701 TaggedValue CreateCallableAdaptor(TF x, const char* name) {
702 return TaggedValue((CallableWrapperUnpackArgs<TF>(x, name)));
703 }
704
705 } // namespace libtf
706 } // namespace tf
707
708 #endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_
709