1 #pragma once 2 3 #include <c10/util/TypeTraits.h> 4 #include <type_traits> 5 6 namespace c10 { 7 8 /** 9 * Represent a function pointer as a C++ type. 10 * This allows using the function pointer as a type 11 * in a template and calling it from inside the template 12 * allows the compiler to inline the call because it 13 * knows the function pointer at compile time. 14 * 15 * Example 1: 16 * int add(int a, int b) {return a + b;} 17 * using Add = TORCH_FN_TYPE(add); 18 * template<class Func> struct Executor { 19 * int execute(int a, int b) { 20 * return Func::func_ptr()(a, b); 21 * } 22 * }; 23 * Executor<Add> executor; 24 * EXPECT_EQ(3, executor.execute(1, 2)); 25 * 26 * Example 2: 27 * int add(int a, int b) {return a + b;} 28 * template<class Func> int execute(Func, int a, int b) { 29 * return Func::func_ptr()(a, b); 30 * } 31 * EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2)); 32 */ 33 template <class FuncType_, FuncType_* func_ptr_> 34 struct CompileTimeFunctionPointer final { 35 static_assert( 36 guts::is_function_type<FuncType_>::value, 37 "TORCH_FN can only wrap function types."); 38 using FuncType = FuncType_; 39 func_ptrfinal40 static constexpr FuncType* func_ptr() { 41 return func_ptr_; 42 } 43 }; 44 45 template <class T> 46 struct is_compile_time_function_pointer : std::false_type {}; 47 template <class FuncType, FuncType* func_ptr> 48 struct is_compile_time_function_pointer< 49 CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {}; 50 51 } // namespace c10 52 53 #define TORCH_FN_TYPE(func) \ 54 ::c10::CompileTimeFunctionPointer< \ 55 std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \ 56 func> 57 #define TORCH_FN(func) TORCH_FN_TYPE(func)() 58