1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/TypeList.h> 4*da0073e9SAndroid Build Coastguard Worker #include <type_traits> 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker namespace c10::guts { 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker /** 9*da0073e9SAndroid Build Coastguard Worker * Access information about result type or arguments from a function type. 10*da0073e9SAndroid Build Coastguard Worker * Example: 11*da0073e9SAndroid Build Coastguard Worker * using A = function_traits<int (float, double)>::return_type // A == int 12*da0073e9SAndroid Build Coastguard Worker * using A = function_traits<int (float, double)>::parameter_types::tuple_type 13*da0073e9SAndroid Build Coastguard Worker * // A == tuple<float, double> 14*da0073e9SAndroid Build Coastguard Worker */ 15*da0073e9SAndroid Build Coastguard Worker template <class Func> 16*da0073e9SAndroid Build Coastguard Worker struct function_traits { 17*da0073e9SAndroid Build Coastguard Worker static_assert( 18*da0073e9SAndroid Build Coastguard Worker !std::is_same_v<Func, Func>, 19*da0073e9SAndroid Build Coastguard Worker "In function_traits<Func>, Func must be a plain function type."); 20*da0073e9SAndroid Build Coastguard Worker }; 21*da0073e9SAndroid Build Coastguard Worker template <class Result, class... Args> 22*da0073e9SAndroid Build Coastguard Worker struct function_traits<Result(Args...)> { 23*da0073e9SAndroid Build Coastguard Worker using func_type = Result(Args...); 24*da0073e9SAndroid Build Coastguard Worker using return_type = Result; 25*da0073e9SAndroid Build Coastguard Worker using parameter_types = typelist::typelist<Args...>; 26*da0073e9SAndroid Build Coastguard Worker static constexpr auto number_of_parameters = sizeof...(Args); 27*da0073e9SAndroid Build Coastguard Worker }; 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker /** 30*da0073e9SAndroid Build Coastguard Worker * infer_function_traits: creates a `function_traits` type for a simple 31*da0073e9SAndroid Build Coastguard Worker * function (pointer) or functor (lambda/struct). Currently does not support 32*da0073e9SAndroid Build Coastguard Worker * class methods. 33*da0073e9SAndroid Build Coastguard Worker */ 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker template <typename Functor> 36*da0073e9SAndroid Build Coastguard Worker struct infer_function_traits { 37*da0073e9SAndroid Build Coastguard Worker using type = function_traits< 38*da0073e9SAndroid Build Coastguard Worker c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>; 39*da0073e9SAndroid Build Coastguard Worker }; 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename... Args> 42*da0073e9SAndroid Build Coastguard Worker struct infer_function_traits<Result (*)(Args...)> { 43*da0073e9SAndroid Build Coastguard Worker using type = function_traits<Result(Args...)>; 44*da0073e9SAndroid Build Coastguard Worker }; 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename... Args> 47*da0073e9SAndroid Build Coastguard Worker struct infer_function_traits<Result(Args...)> { 48*da0073e9SAndroid Build Coastguard Worker using type = function_traits<Result(Args...)>; 49*da0073e9SAndroid Build Coastguard Worker }; 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker template <typename T> 52*da0073e9SAndroid Build Coastguard Worker using infer_function_traits_t = typename infer_function_traits<T>::type; 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker /** 55*da0073e9SAndroid Build Coastguard Worker * make_function_traits: creates a `function_traits` type given a Return type 56*da0073e9SAndroid Build Coastguard Worker * and a typelist of Argument types 57*da0073e9SAndroid Build Coastguard Worker * 58*da0073e9SAndroid Build Coastguard Worker * Example: 59*da0073e9SAndroid Build Coastguard Worker * bool f(int, int); 60*da0073e9SAndroid Build Coastguard Worker * 61*da0073e9SAndroid Build Coastguard Worker * infer_function_traits_t<f> == make_function_traits_t<bool, 62*da0073e9SAndroid Build Coastguard Worker * typelist::typelist<int, int>> 63*da0073e9SAndroid Build Coastguard Worker */ 64*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename ArgList> 65*da0073e9SAndroid Build Coastguard Worker struct make_function_traits { 66*da0073e9SAndroid Build Coastguard Worker static_assert( 67*da0073e9SAndroid Build Coastguard Worker false_t<ArgList>::value, 68*da0073e9SAndroid Build Coastguard Worker "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>."); 69*da0073e9SAndroid Build Coastguard Worker }; 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename... Args> 72*da0073e9SAndroid Build Coastguard Worker struct make_function_traits<Result, typelist::typelist<Args...>> { 73*da0073e9SAndroid Build Coastguard Worker using type = function_traits<Result(Args...)>; 74*da0073e9SAndroid Build Coastguard Worker }; 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename ArgList> 77*da0073e9SAndroid Build Coastguard Worker using make_function_traits_t = 78*da0073e9SAndroid Build Coastguard Worker typename make_function_traits<Result, ArgList>::type; 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker /** 81*da0073e9SAndroid Build Coastguard Worker * make_offset_index_sequence<Start, N> 82*da0073e9SAndroid Build Coastguard Worker * Like make_index_sequence<N>, but starting from Start instead of 0. 83*da0073e9SAndroid Build Coastguard Worker * 84*da0073e9SAndroid Build Coastguard Worker * Example: 85*da0073e9SAndroid Build Coastguard Worker * make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12> 86*da0073e9SAndroid Build Coastguard Worker */ 87*da0073e9SAndroid Build Coastguard Worker template <size_t Start, size_t N, size_t... Is> 88*da0073e9SAndroid Build Coastguard Worker struct make_offset_index_sequence_impl 89*da0073e9SAndroid Build Coastguard Worker : make_offset_index_sequence_impl<Start, N - 1, Start + N - 1, Is...> { 90*da0073e9SAndroid Build Coastguard Worker static_assert( 91*da0073e9SAndroid Build Coastguard Worker static_cast<int>(Start) >= 0, 92*da0073e9SAndroid Build Coastguard Worker "make_offset_index_sequence: Start < 0"); 93*da0073e9SAndroid Build Coastguard Worker static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0"); 94*da0073e9SAndroid Build Coastguard Worker }; 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker template <size_t Start, size_t... Is> 97*da0073e9SAndroid Build Coastguard Worker struct make_offset_index_sequence_impl<Start, 0, Is...> { 98*da0073e9SAndroid Build Coastguard Worker typedef std::index_sequence<Is...> type; 99*da0073e9SAndroid Build Coastguard Worker }; 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker template <size_t Start, size_t N> 102*da0073e9SAndroid Build Coastguard Worker using make_offset_index_sequence = 103*da0073e9SAndroid Build Coastguard Worker typename make_offset_index_sequence_impl<Start, N>::type; 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker /** 106*da0073e9SAndroid Build Coastguard Worker * Use tuple_elements to extract a position-indexed subset of elements 107*da0073e9SAndroid Build Coastguard Worker * from the argument tuple into a result tuple. 108*da0073e9SAndroid Build Coastguard Worker * 109*da0073e9SAndroid Build Coastguard Worker * Example: 110*da0073e9SAndroid Build Coastguard Worker * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0); 111*da0073e9SAndroid Build Coastguard Worker * std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0, 112*da0073e9SAndroid Build Coastguard Worker * 2>()); 113*da0073e9SAndroid Build Coastguard Worker */ 114*da0073e9SAndroid Build Coastguard Worker template <class Tuple, size_t... Is> 115*da0073e9SAndroid Build Coastguard Worker constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) { 116*da0073e9SAndroid Build Coastguard Worker return std::tuple<std::tuple_element_t<Is, Tuple>...>(std::get<Is>(t)...); 117*da0073e9SAndroid Build Coastguard Worker } 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker /** 120*da0073e9SAndroid Build Coastguard Worker * Use tuple_take to extract the first or last n elements from the argument 121*da0073e9SAndroid Build Coastguard Worker * tuple into a result tuple. 122*da0073e9SAndroid Build Coastguard Worker * 123*da0073e9SAndroid Build Coastguard Worker * Example: 124*da0073e9SAndroid Build Coastguard Worker * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0); 125*da0073e9SAndroid Build Coastguard Worker * std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t); 126*da0073e9SAndroid Build Coastguard Worker * std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t); 127*da0073e9SAndroid Build Coastguard Worker */ 128*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N, class Enable = void> 129*da0073e9SAndroid Build Coastguard Worker struct TupleTake {}; 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N> 132*da0073e9SAndroid Build Coastguard Worker struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> { 133*da0073e9SAndroid Build Coastguard Worker static auto call(Tuple t) { 134*da0073e9SAndroid Build Coastguard Worker constexpr size_t size = std::tuple_size<Tuple>(); 135*da0073e9SAndroid Build Coastguard Worker static_assert(N <= size, "tuple_take: N > size"); 136*da0073e9SAndroid Build Coastguard Worker return tuple_elements(t, std::make_index_sequence<N>{}); 137*da0073e9SAndroid Build Coastguard Worker } 138*da0073e9SAndroid Build Coastguard Worker }; 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N> 141*da0073e9SAndroid Build Coastguard Worker struct TupleTake < Tuple, 142*da0073e9SAndroid Build Coastguard Worker N, std::enable_if_t<N<0, void>> { 143*da0073e9SAndroid Build Coastguard Worker static auto call(Tuple t) { 144*da0073e9SAndroid Build Coastguard Worker constexpr size_t size = std::tuple_size<Tuple>(); 145*da0073e9SAndroid Build Coastguard Worker static_assert(-N <= size, "tuple_take: -N > size"); 146*da0073e9SAndroid Build Coastguard Worker return tuple_elements(t, make_offset_index_sequence<size + N, -N>{}); 147*da0073e9SAndroid Build Coastguard Worker } 148*da0073e9SAndroid Build Coastguard Worker }; 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N> 151*da0073e9SAndroid Build Coastguard Worker auto tuple_take(Tuple t) { 152*da0073e9SAndroid Build Coastguard Worker return TupleTake<Tuple, N>::call(t); 153*da0073e9SAndroid Build Coastguard Worker } 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker /** 156*da0073e9SAndroid Build Coastguard Worker * Use tuple_slice to extract a contiguous subtuple from the argument. 157*da0073e9SAndroid Build Coastguard Worker * 158*da0073e9SAndroid Build Coastguard Worker * Example: 159*da0073e9SAndroid Build Coastguard Worker * std::tuple<int, const char*, double, bool> t = std::make_tuple(0, 160*da0073e9SAndroid Build Coastguard Worker * "HEY", 2.0, false); std::tuple<int, const char*> middle_two = 161*da0073e9SAndroid Build Coastguard Worker * tuple_slice<decltype(t), 1, 2>(t); 162*da0073e9SAndroid Build Coastguard Worker */ 163*da0073e9SAndroid Build Coastguard Worker template <class Tuple, size_t Start, size_t N> 164*da0073e9SAndroid Build Coastguard Worker constexpr auto tuple_slice(Tuple t) { 165*da0073e9SAndroid Build Coastguard Worker constexpr size_t size = std::tuple_size<Tuple>(); 166*da0073e9SAndroid Build Coastguard Worker static_assert(Start + N <= size, "tuple_slice: Start + N > size"); 167*da0073e9SAndroid Build Coastguard Worker return tuple_elements(t, make_offset_index_sequence<Start, N>{}); 168*da0073e9SAndroid Build Coastguard Worker } 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker /** 171*da0073e9SAndroid Build Coastguard Worker * Use tuple_map to run a mapping function over a tuple to get a new tuple. 172*da0073e9SAndroid Build Coastguard Worker * 173*da0073e9SAndroid Build Coastguard Worker * Example 1: 174*da0073e9SAndroid Build Coastguard Worker * auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), [] 175*da0073e9SAndroid Build Coastguard Worker * (int32_t a) -> int16_t {return a+1;}); 176*da0073e9SAndroid Build Coastguard Worker * // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6) 177*da0073e9SAndroid Build Coastguard Worker * 178*da0073e9SAndroid Build Coastguard Worker * Example 2: 179*da0073e9SAndroid Build Coastguard Worker * struct Mapper { 180*da0073e9SAndroid Build Coastguard Worker * std::string operator()(int32_t a) const { 181*da0073e9SAndroid Build Coastguard Worker * return std::to_string(a); 182*da0073e9SAndroid Build Coastguard Worker * } 183*da0073e9SAndroid Build Coastguard Worker * int64_t operator()(const std::string& a) const { 184*da0073e9SAndroid Build Coastguard Worker * return atoi(a.c_str()); 185*da0073e9SAndroid Build Coastguard Worker * } 186*da0073e9SAndroid Build Coastguard Worker * }; 187*da0073e9SAndroid Build Coastguard Worker * auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"), 188*da0073e9SAndroid Build Coastguard Worker * Mapper()); 189*da0073e9SAndroid Build Coastguard Worker * // result == std::tuple<std::string, int64_t>("3", 4) 190*da0073e9SAndroid Build Coastguard Worker * 191*da0073e9SAndroid Build Coastguard Worker * Example 3: 192*da0073e9SAndroid Build Coastguard Worker * struct A final { 193*da0073e9SAndroid Build Coastguard Worker * int32_t func() { 194*da0073e9SAndroid Build Coastguard Worker * return 5; 195*da0073e9SAndroid Build Coastguard Worker * } 196*da0073e9SAndroid Build Coastguard Worker * }; 197*da0073e9SAndroid Build Coastguard Worker * struct B final { 198*da0073e9SAndroid Build Coastguard Worker * std::string func() { 199*da0073e9SAndroid Build Coastguard Worker * return "5"; 200*da0073e9SAndroid Build Coastguard Worker * } 201*da0073e9SAndroid Build Coastguard Worker * }; 202*da0073e9SAndroid Build Coastguard Worker * auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return 203*da0073e9SAndroid Build Coastguard Worker * a.func(); }); 204*da0073e9SAndroid Build Coastguard Worker * // result == std::tuple<int32_t, std::string>(5, "5"); 205*da0073e9SAndroid Build Coastguard Worker */ 206*da0073e9SAndroid Build Coastguard Worker namespace detail { 207*da0073e9SAndroid Build Coastguard Worker template <class Mapper, class... Args, size_t... Indices> 208*da0073e9SAndroid Build Coastguard Worker auto tuple_map( 209*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) 210*da0073e9SAndroid Build Coastguard Worker std::tuple<Args...>&& tuple, 211*da0073e9SAndroid Build Coastguard Worker const Mapper& mapper, 212*da0073e9SAndroid Build Coastguard Worker std::index_sequence<Indices...>) { 213*da0073e9SAndroid Build Coastguard Worker return std::tuple<decltype(mapper(std::forward<Args>(std::get<Indices>( 214*da0073e9SAndroid Build Coastguard Worker tuple))))...>(mapper(std::forward<Args>(std::get<Indices>(tuple)))...); 215*da0073e9SAndroid Build Coastguard Worker } 216*da0073e9SAndroid Build Coastguard Worker } // namespace detail 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker template <class Mapper, class... Args> 219*da0073e9SAndroid Build Coastguard Worker auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) { 220*da0073e9SAndroid Build Coastguard Worker return detail::tuple_map( 221*da0073e9SAndroid Build Coastguard Worker std::move(tuple), mapper, std::index_sequence_for<Args...>()); 222*da0073e9SAndroid Build Coastguard Worker } 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker } // namespace c10::guts 225