1 #pragma once 2 3 #include <c10/util/TypeList.h> 4 #include <type_traits> 5 6 namespace c10::guts { 7 8 /** 9 * Access information about result type or arguments from a function type. 10 * Example: 11 * using A = function_traits<int (float, double)>::return_type // A == int 12 * using A = function_traits<int (float, double)>::parameter_types::tuple_type 13 * // A == tuple<float, double> 14 */ 15 template <class Func> 16 struct function_traits { 17 static_assert( 18 !std::is_same_v<Func, Func>, 19 "In function_traits<Func>, Func must be a plain function type."); 20 }; 21 template <class Result, class... Args> 22 struct function_traits<Result(Args...)> { 23 using func_type = Result(Args...); 24 using return_type = Result; 25 using parameter_types = typelist::typelist<Args...>; 26 static constexpr auto number_of_parameters = sizeof...(Args); 27 }; 28 29 /** 30 * infer_function_traits: creates a `function_traits` type for a simple 31 * function (pointer) or functor (lambda/struct). Currently does not support 32 * class methods. 33 */ 34 35 template <typename Functor> 36 struct infer_function_traits { 37 using type = function_traits< 38 c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>; 39 }; 40 41 template <typename Result, typename... Args> 42 struct infer_function_traits<Result (*)(Args...)> { 43 using type = function_traits<Result(Args...)>; 44 }; 45 46 template <typename Result, typename... Args> 47 struct infer_function_traits<Result(Args...)> { 48 using type = function_traits<Result(Args...)>; 49 }; 50 51 template <typename T> 52 using infer_function_traits_t = typename infer_function_traits<T>::type; 53 54 /** 55 * make_function_traits: creates a `function_traits` type given a Return type 56 * and a typelist of Argument types 57 * 58 * Example: 59 * bool f(int, int); 60 * 61 * infer_function_traits_t<f> == make_function_traits_t<bool, 62 * typelist::typelist<int, int>> 63 */ 64 template <typename Result, typename ArgList> 65 struct make_function_traits { 66 static_assert( 67 false_t<ArgList>::value, 68 "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>."); 69 }; 70 71 template <typename Result, typename... Args> 72 struct make_function_traits<Result, typelist::typelist<Args...>> { 73 using type = function_traits<Result(Args...)>; 74 }; 75 76 template <typename Result, typename ArgList> 77 using make_function_traits_t = 78 typename make_function_traits<Result, ArgList>::type; 79 80 /** 81 * make_offset_index_sequence<Start, N> 82 * Like make_index_sequence<N>, but starting from Start instead of 0. 83 * 84 * Example: 85 * make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12> 86 */ 87 template <size_t Start, size_t N, size_t... Is> 88 struct make_offset_index_sequence_impl 89 : make_offset_index_sequence_impl<Start, N - 1, Start + N - 1, Is...> { 90 static_assert( 91 static_cast<int>(Start) >= 0, 92 "make_offset_index_sequence: Start < 0"); 93 static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0"); 94 }; 95 96 template <size_t Start, size_t... Is> 97 struct make_offset_index_sequence_impl<Start, 0, Is...> { 98 typedef std::index_sequence<Is...> type; 99 }; 100 101 template <size_t Start, size_t N> 102 using make_offset_index_sequence = 103 typename make_offset_index_sequence_impl<Start, N>::type; 104 105 /** 106 * Use tuple_elements to extract a position-indexed subset of elements 107 * from the argument tuple into a result tuple. 108 * 109 * Example: 110 * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0); 111 * std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0, 112 * 2>()); 113 */ 114 template <class Tuple, size_t... Is> 115 constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) { 116 return std::tuple<std::tuple_element_t<Is, Tuple>...>(std::get<Is>(t)...); 117 } 118 119 /** 120 * Use tuple_take to extract the first or last n elements from the argument 121 * tuple into a result tuple. 122 * 123 * Example: 124 * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0); 125 * std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t); 126 * std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t); 127 */ 128 template <class Tuple, int N, class Enable = void> 129 struct TupleTake {}; 130 131 template <class Tuple, int N> 132 struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> { 133 static auto call(Tuple t) { 134 constexpr size_t size = std::tuple_size<Tuple>(); 135 static_assert(N <= size, "tuple_take: N > size"); 136 return tuple_elements(t, std::make_index_sequence<N>{}); 137 } 138 }; 139 140 template <class Tuple, int N> 141 struct TupleTake < Tuple, 142 N, std::enable_if_t<N<0, void>> { 143 static auto call(Tuple t) { 144 constexpr size_t size = std::tuple_size<Tuple>(); 145 static_assert(-N <= size, "tuple_take: -N > size"); 146 return tuple_elements(t, make_offset_index_sequence<size + N, -N>{}); 147 } 148 }; 149 150 template <class Tuple, int N> 151 auto tuple_take(Tuple t) { 152 return TupleTake<Tuple, N>::call(t); 153 } 154 155 /** 156 * Use tuple_slice to extract a contiguous subtuple from the argument. 157 * 158 * Example: 159 * std::tuple<int, const char*, double, bool> t = std::make_tuple(0, 160 * "HEY", 2.0, false); std::tuple<int, const char*> middle_two = 161 * tuple_slice<decltype(t), 1, 2>(t); 162 */ 163 template <class Tuple, size_t Start, size_t N> 164 constexpr auto tuple_slice(Tuple t) { 165 constexpr size_t size = std::tuple_size<Tuple>(); 166 static_assert(Start + N <= size, "tuple_slice: Start + N > size"); 167 return tuple_elements(t, make_offset_index_sequence<Start, N>{}); 168 } 169 170 /** 171 * Use tuple_map to run a mapping function over a tuple to get a new tuple. 172 * 173 * Example 1: 174 * auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), [] 175 * (int32_t a) -> int16_t {return a+1;}); 176 * // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6) 177 * 178 * Example 2: 179 * struct Mapper { 180 * std::string operator()(int32_t a) const { 181 * return std::to_string(a); 182 * } 183 * int64_t operator()(const std::string& a) const { 184 * return atoi(a.c_str()); 185 * } 186 * }; 187 * auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"), 188 * Mapper()); 189 * // result == std::tuple<std::string, int64_t>("3", 4) 190 * 191 * Example 3: 192 * struct A final { 193 * int32_t func() { 194 * return 5; 195 * } 196 * }; 197 * struct B final { 198 * std::string func() { 199 * return "5"; 200 * } 201 * }; 202 * auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return 203 * a.func(); }); 204 * // result == std::tuple<int32_t, std::string>(5, "5"); 205 */ 206 namespace detail { 207 template <class Mapper, class... Args, size_t... Indices> 208 auto tuple_map( 209 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) 210 std::tuple<Args...>&& tuple, 211 const Mapper& mapper, 212 std::index_sequence<Indices...>) { 213 return std::tuple<decltype(mapper(std::forward<Args>(std::get<Indices>( 214 tuple))))...>(mapper(std::forward<Args>(std::get<Indices>(tuple)))...); 215 } 216 } // namespace detail 217 218 template <class Mapper, class... Args> 219 auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) { 220 return detail::tuple_map( 221 std::move(tuple), mapper, std::index_sequence_for<Args...>()); 222 } 223 224 } // namespace c10::guts 225