xref: /aosp_15_r20/external/pytorch/c10/core/CompileTimeFunctionPointer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/TypeTraits.h>
4 #include <type_traits>
5 
6 namespace c10 {
7 
8 /**
9  * Represent a function pointer as a C++ type.
10  * This allows using the function pointer as a type
11  * in a template and calling it from inside the template
12  * allows the compiler to inline the call because it
13  * knows the function pointer at compile time.
14  *
15  * Example 1:
16  *  int add(int a, int b) {return a + b;}
17  *  using Add = TORCH_FN_TYPE(add);
18  *  template<class Func> struct Executor {
19  *    int execute(int a, int b) {
20  *      return Func::func_ptr()(a, b);
21  *    }
22  *  };
23  *  Executor<Add> executor;
24  *  EXPECT_EQ(3, executor.execute(1, 2));
25  *
26  * Example 2:
27  *  int add(int a, int b) {return a + b;}
28  *  template<class Func> int execute(Func, int a, int b) {
29  *    return Func::func_ptr()(a, b);
30  *  }
31  *  EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
32  */
33 template <class FuncType_, FuncType_* func_ptr_>
34 struct CompileTimeFunctionPointer final {
35   static_assert(
36       guts::is_function_type<FuncType_>::value,
37       "TORCH_FN can only wrap function types.");
38   using FuncType = FuncType_;
39 
func_ptrfinal40   static constexpr FuncType* func_ptr() {
41     return func_ptr_;
42   }
43 };
44 
45 template <class T>
46 struct is_compile_time_function_pointer : std::false_type {};
47 template <class FuncType, FuncType* func_ptr>
48 struct is_compile_time_function_pointer<
49     CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
50 
51 } // namespace c10
52 
/**
 * TORCH_FN_TYPE(func) names the ::c10::CompileTimeFunctionPointer
 * specialization for `func`. The function type is taken from decltype(func)
 * after removing a reference and then one level of pointer, and `func` itself
 * is passed as the pointer argument.
 */
#define TORCH_FN_TYPE(func)                                                  \
  ::c10::CompileTimeFunctionPointer<                                         \
      std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, func>

/**
 * TORCH_FN(func) expands to a value-initialized temporary of type
 * TORCH_FN_TYPE(func).
 */
#define TORCH_FN(func) TORCH_FN_TYPE(func)()
58