xref: /aosp_15_r20/external/executorch/extension/kernel_util/meta_programming.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 #if __cplusplus < 201703L
11 #error "This header requires C++17"
12 #endif
13 
14 #include <executorch/extension/kernel_util/type_list.h>
15 #include <cstdlib>
16 #include <memory>
17 #include <type_traits>
18 #include <typeinfo>
19 
20 namespace executorch {
21 namespace extension {
22 // This extension has a lot of generic internal names like "size"; use a unique
23 // internal namespace to avoid conflicts with other extensions.
24 namespace kernel_util_internal {
25 
// Check if a given type is a (non-member) function type, e.g. `int(float)`.
//
// Since C++17 the exception specification is part of the function type, so
// `void(int) noexcept` is a distinct type from `void(int)`. Both forms are
// recognized here so that functions declared `noexcept` pass the check in
// CompileTimeFunctionPointer. C-style variadic and cv/ref-qualified function
// types are not matched.
template <class T>
struct is_function_type : std::false_type {};
template <class Result, class... Args>
struct is_function_type<Result(Args...)> : std::true_type {};
template <class Result, class... Args>
struct is_function_type<Result(Args...) noexcept> : std::true_type {};
template <class T>
using is_function_type_t = typename is_function_type<T>::type;
33 
34 // A compile-time wrapper around a function pointer
35 template <class FuncType_, FuncType_* func_ptr_>
36 struct CompileTimeFunctionPointer final {
37   static_assert(
38       is_function_type<FuncType_>::value,
39       "EXECUTORCH_FN can only wrap function types.");
40   using FuncType = FuncType_;
41 
42   static constexpr FuncType* func_ptr() {
43     return func_ptr_;
44   }
45 };
46 
47 // Check if a given type is a compile-time function pointer
48 template <class T>
49 struct is_compile_time_function_pointer : std::false_type {};
50 template <class FuncType, FuncType* func_ptr>
51 struct is_compile_time_function_pointer<
52     CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
53 
// EXECUTORCH_FN_TYPE(func) names the CompileTimeFunctionPointer type wrapping
// `func`. The function type is recovered from decltype(func) by removing any
// reference and then any (possibly cv-qualified) pointer layer, so `func` may
// be a function name or a constant function-pointer expression.
//
// EXECUTORCH_FN(func) expands to `EXECUTORCH_FN_TYPE(func)()`, i.e. an
// instance of that type.
//
// Macros expand in the caller's scope, so every name is fully qualified
// (including `::std`) to avoid resolving to unrelated user declarations.
#define EXECUTORCH_FN_TYPE(func)                                             \
  ::executorch::extension::kernel_util_internal::CompileTimeFunctionPointer< \
      ::std::remove_pointer_t<::std::remove_reference_t<decltype(func)>>,    \
      func>
#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)()
59 
/**
 * strip_class: helper to remove the class type from pointers to `operator()`.
 *
 * Maps a pointer-to-member-function type such as `R (C::*)(Args...) const`
 * to the plain function type `R(Args...)`. The `const` qualifier and the
 * `noexcept` specifier (part of the function type since C++17) are dropped
 * too, so call operators of `noexcept` lambdas and functors are supported.
 * For any other T there is no nested `type`.
 */
template <typename T>
struct strip_class {};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...)> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) noexcept> {
  using type = Result(Args...);
};
template <typename Class, typename Result, typename... Args>
struct strip_class<Result (Class::*)(Args...) const noexcept> {
  using type = Result(Args...);
};
template <typename T>
using strip_class_t = typename strip_class<T>::type;
75 
76 /**
77  * Access information about result type or arguments from a function type.
78  * Example:
79  * using A = function_traits<int (float, double)>::return_type // A == int
80  * using A = function_traits<int (float, double)>::parameter_types::tuple_type
81  * // A == tuple<float, double>
82  */
83 template <class Func>
84 struct function_traits {
85   static_assert(
86       !std::is_same<Func, Func>::value,
87       "In function_traits<Func>, Func must be a plain function type.");
88 };
89 template <class Result, class... Args>
90 struct function_traits<Result(Args...)> {
91   using func_type = Result(Args...);
92   using return_type = Result;
93   using parameter_types = typelist<Args...>;
94   static constexpr auto number_of_parameters = sizeof...(Args);
95 };
96 
97 /**
98  * infer_function_traits: creates a `function_traits` type for a simple
99  * function (pointer) or functor (lambda/struct). Currently does not support
100  * class methods.
101  */
102 template <typename Functor>
103 struct infer_function_traits {
104   using type = function_traits<strip_class_t<decltype(&Functor::operator())>>;
105 };
106 template <typename Result, typename... Args>
107 struct infer_function_traits<Result (*)(Args...)> {
108   using type = function_traits<Result(Args...)>;
109 };
110 template <typename Result, typename... Args>
111 struct infer_function_traits<Result(Args...)> {
112   using type = function_traits<Result(Args...)>;
113 };
114 template <typename T>
115 using infer_function_traits_t = typename infer_function_traits<T>::type;
116 
117 } // namespace kernel_util_internal
118 } // namespace extension
119 } // namespace executorch
120