1 #pragma once 2 3 /** 4 * This file contains functionality to take a C++ function and infer its 5 * c10::FunctionSchema. 6 */ 7 8 #include <ATen/core/function_schema.h> 9 #include <c10/util/Metaprogramming.h> 10 11 namespace c10 { 12 namespace detail { 13 14 namespace infer_schema { 15 16 /// The templated inference code creates `ArgumentDef` instead of `Argument`, 17 /// because that can be constructed at compile time and has a much smaller 18 /// binary size than having calls to `Argument` constructors in the template. 19 /// Creating `Argument` objects from `ArgumentDef` can then be done at 20 /// runtime in a non-templated way. 21 struct ArgumentDef final { 22 using GetTypeFn = TypePtr(); 23 GetTypeFn* getTypeFn; 24 GetTypeFn* getFakeTypeFn; ArgumentDeffinal25 constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {} ArgumentDeffinal26 explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {} 27 }; 28 29 template<bool V> 30 struct bool_t {}; 31 template<> struct bool_t<true> : std::true_type {}; 32 template<> struct bool_t<false> : std::false_type {}; 33 34 /// Checks the static C++ types `Types` for correctness to catch common error cases. 35 template <class... Types> 36 constexpr int checkStaticTypes() { 37 // Give nice error messages for some of the common error cases. 38 // Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT 39 static_assert(std::conjunction< 40 bool_t<!std::is_integral<Types>::value || std::is_same<Types, int8_t>::value || std::is_same<Types, int64_t>::value || std::is_same<Types, bool>::value>... 41 >::value, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type"); 42 static_assert(std::conjunction< 43 bool_t<!std::is_same<Types, float>::value>... 44 >::value, "INVALID TYPE: float is not supported as an argument type, use double instead"); 45 return 0; 46 } 47 48 template <typename... Ts, size_t... Is> 49 constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...>) { 50 return ( 51 // Check types for common errors 52 checkStaticTypes<Ts...>(), 53 54 // Create the return value 55 std::array<ArgumentDef, sizeof...(Ts)>{ 56 ArgumentDef(&getTypePtrCopy<std::decay_t<Ts>>, &getFakeTypePtrCopy<std::decay_t<Ts>>)...} 57 ); 58 } 59 60 /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified 61 /// as template arguments. 62 template<class ParameterTypes> struct createArguments final {}; 63 template<class... ParameterTypes> 64 struct createArguments<guts::typelist::typelist<ParameterTypes...>> final { 65 static constexpr std::array<ArgumentDef, sizeof...(ParameterTypes)> call() { 66 return createArgumentVectorFromTypes<ParameterTypes...>( 67 std::make_index_sequence<sizeof...(ParameterTypes)>() 68 ); 69 } 70 }; 71 72 /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified 73 /// as a tuple (i.e. in the way c10 kernels return values). 74 /// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C. 75 /// It can be an empty tuple<>, or void for kernels that don't return anything. 76 /// It can be a single type A (i.e. no tuple) for the case where a kernel just 77 /// returns one value. 78 template<class ReturnTypeTuple, class Enable = void> struct createReturns final {}; 79 80 template<class... ReturnTypes> 81 struct createReturns<std::tuple<ReturnTypes...>, void> final { 82 static constexpr std::array<ArgumentDef, sizeof...(ReturnTypes)> call() { 83 return createArgumentVectorFromTypes<ReturnTypes...>( 84 std::make_index_sequence<sizeof...(ReturnTypes)>() 85 ); 86 } 87 }; 88 89 template<class ReturnType> 90 struct createReturns<ReturnType, std::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final { 91 static constexpr std::array<ArgumentDef, 1> call() { 92 return createReturns<std::tuple<ReturnType>>::call(); 93 } 94 }; 95 96 template<> 97 struct createReturns<void, void> final { 98 static constexpr std::array<ArgumentDef, 0> call() { 99 return createReturns<std::tuple<>>::call(); 100 } 101 }; 102 103 template <typename ReturnType> 104 struct createSingleReturn { 105 static constexpr std::array<ArgumentDef, 1> call() { 106 return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>()); 107 } 108 }; 109 110 TORCH_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns); 111 TORCH_API FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns); 112 113 /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a 114 /// function. Flattens std::tuple returns into multiple return types 115 template <typename FunctionTraits> 116 FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() { 117 using ReturnType = typename FunctionTraits::return_type; 118 using ParameterTypes = typename FunctionTraits::parameter_types; 119 120 // arguments and returns are computed into a std::array at compile time and embedded into the binary. 121 // The only code executed at runtime here is the one that creates a std::vector 122 // of the arguments/returns from the std::array. 123 constexpr auto arguments = createArguments<ParameterTypes>::call(); 124 constexpr auto returns = createReturns<ReturnType>::call(); 125 126 return make_function_schema(arguments, returns); 127 } 128 129 /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a 130 /// function. Preserves std::tuple returns as a Tuple return type 131 template <typename FunctionTraits> 132 FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) { 133 using ReturnType = typename FunctionTraits::return_type; 134 using ParameterTypes = typename FunctionTraits::parameter_types; 135 136 // arguments and returns are computed into a std::array at compile time and embedded into the binary. 137 // The only code executed at runtime here is the one that creates a std::vector 138 // of the arguments/returns from the std::array. 139 constexpr auto arguments = createArguments<ParameterTypes>::call(); 140 constexpr auto returns = createSingleReturn<ReturnType>::call(); 141 142 return make_function_schema(std::move(name), std::move(overload_name), arguments, returns); 143 } 144 145 } 146 } 147 148 template<class FuncType> 149 FunctionSchema inferFunctionSchemaFlattenedReturns() { 150 return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>(); 151 } 152 153 template<class FuncType> 154 FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) { 155 return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name)); 156 } 157 158 TORCH_API std::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified); 159 160 } 161