xref: /aosp_15_r20/external/pytorch/c10/util/Metaprogramming.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/util/TypeList.h>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10::guts {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker /**
9*da0073e9SAndroid Build Coastguard Worker  * Access information about result type or arguments from a function type.
10*da0073e9SAndroid Build Coastguard Worker  * Example:
11*da0073e9SAndroid Build Coastguard Worker  * using A = function_traits<int (float, double)>::return_type // A == int
12*da0073e9SAndroid Build Coastguard Worker  * using A = function_traits<int (float, double)>::parameter_types::tuple_type
13*da0073e9SAndroid Build Coastguard Worker  * // A == tuple<float, double>
14*da0073e9SAndroid Build Coastguard Worker  */
15*da0073e9SAndroid Build Coastguard Worker template <class Func>
16*da0073e9SAndroid Build Coastguard Worker struct function_traits {
17*da0073e9SAndroid Build Coastguard Worker   static_assert(
18*da0073e9SAndroid Build Coastguard Worker       !std::is_same_v<Func, Func>,
19*da0073e9SAndroid Build Coastguard Worker       "In function_traits<Func>, Func must be a plain function type.");
20*da0073e9SAndroid Build Coastguard Worker };
21*da0073e9SAndroid Build Coastguard Worker template <class Result, class... Args>
22*da0073e9SAndroid Build Coastguard Worker struct function_traits<Result(Args...)> {
23*da0073e9SAndroid Build Coastguard Worker   using func_type = Result(Args...);
24*da0073e9SAndroid Build Coastguard Worker   using return_type = Result;
25*da0073e9SAndroid Build Coastguard Worker   using parameter_types = typelist::typelist<Args...>;
26*da0073e9SAndroid Build Coastguard Worker   static constexpr auto number_of_parameters = sizeof...(Args);
27*da0073e9SAndroid Build Coastguard Worker };
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker /**
30*da0073e9SAndroid Build Coastguard Worker  * infer_function_traits: creates a `function_traits` type for a simple
31*da0073e9SAndroid Build Coastguard Worker  * function (pointer) or functor (lambda/struct). Currently does not support
32*da0073e9SAndroid Build Coastguard Worker  * class methods.
33*da0073e9SAndroid Build Coastguard Worker  */
34*da0073e9SAndroid Build Coastguard Worker 
35*da0073e9SAndroid Build Coastguard Worker template <typename Functor>
36*da0073e9SAndroid Build Coastguard Worker struct infer_function_traits {
37*da0073e9SAndroid Build Coastguard Worker   using type = function_traits<
38*da0073e9SAndroid Build Coastguard Worker       c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
39*da0073e9SAndroid Build Coastguard Worker };
40*da0073e9SAndroid Build Coastguard Worker 
41*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename... Args>
42*da0073e9SAndroid Build Coastguard Worker struct infer_function_traits<Result (*)(Args...)> {
43*da0073e9SAndroid Build Coastguard Worker   using type = function_traits<Result(Args...)>;
44*da0073e9SAndroid Build Coastguard Worker };
45*da0073e9SAndroid Build Coastguard Worker 
46*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename... Args>
47*da0073e9SAndroid Build Coastguard Worker struct infer_function_traits<Result(Args...)> {
48*da0073e9SAndroid Build Coastguard Worker   using type = function_traits<Result(Args...)>;
49*da0073e9SAndroid Build Coastguard Worker };
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker template <typename T>
52*da0073e9SAndroid Build Coastguard Worker using infer_function_traits_t = typename infer_function_traits<T>::type;
53*da0073e9SAndroid Build Coastguard Worker 
54*da0073e9SAndroid Build Coastguard Worker /**
55*da0073e9SAndroid Build Coastguard Worker  * make_function_traits: creates a `function_traits` type given a Return type
56*da0073e9SAndroid Build Coastguard Worker  * and a typelist of Argument types
57*da0073e9SAndroid Build Coastguard Worker  *
58*da0073e9SAndroid Build Coastguard Worker  * Example:
59*da0073e9SAndroid Build Coastguard Worker  * bool f(int, int);
60*da0073e9SAndroid Build Coastguard Worker  *
61*da0073e9SAndroid Build Coastguard Worker  * infer_function_traits_t<f> == make_function_traits_t<bool,
62*da0073e9SAndroid Build Coastguard Worker  * typelist::typelist<int, int>>
63*da0073e9SAndroid Build Coastguard Worker  */
64*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename ArgList>
65*da0073e9SAndroid Build Coastguard Worker struct make_function_traits {
66*da0073e9SAndroid Build Coastguard Worker   static_assert(
67*da0073e9SAndroid Build Coastguard Worker       false_t<ArgList>::value,
68*da0073e9SAndroid Build Coastguard Worker       "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
69*da0073e9SAndroid Build Coastguard Worker };
70*da0073e9SAndroid Build Coastguard Worker 
71*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename... Args>
72*da0073e9SAndroid Build Coastguard Worker struct make_function_traits<Result, typelist::typelist<Args...>> {
73*da0073e9SAndroid Build Coastguard Worker   using type = function_traits<Result(Args...)>;
74*da0073e9SAndroid Build Coastguard Worker };
75*da0073e9SAndroid Build Coastguard Worker 
76*da0073e9SAndroid Build Coastguard Worker template <typename Result, typename ArgList>
77*da0073e9SAndroid Build Coastguard Worker using make_function_traits_t =
78*da0073e9SAndroid Build Coastguard Worker     typename make_function_traits<Result, ArgList>::type;
79*da0073e9SAndroid Build Coastguard Worker 
/**
 * make_offset_index_sequence<Start, N>
 * Like std::make_index_sequence<N>, but counting from Start instead of 0.
 *
 * Example:
 *  make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
 *  make_offset_index_sequence<7, 0>  == std::index_sequence<>
 *
 * The sequence is derived from std::make_index_sequence<N>, so the template
 * instantiation depth is constant. Large N no longer runs into the compiler's
 * template-depth limit, as a one-instantiation-per-element recursion would.
 */
namespace detail {
// Shifts every index of an std::index_sequence by Start, then appends Tail...
template <size_t Start, class Seq, size_t... Tail>
struct shift_index_sequence;

template <size_t Start, size_t... Js, size_t... Tail>
struct shift_index_sequence<Start, std::index_sequence<Js...>, Tail...> {
  using type = std::index_sequence<(Start + Js)..., Tail...>;
};
} // namespace detail

// `type` is std::index_sequence<Start, ..., Start + N - 1, Is...>.
template <size_t Start, size_t N, size_t... Is>
struct make_offset_index_sequence_impl {
  // These catch negative ints that were converted to size_t by the caller.
  // Start is irrelevant (and not checked) when the sequence is empty.
  static_assert(
      N == 0 || static_cast<int>(Start) >= 0,
      "make_offset_index_sequence: Start < 0");
  static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");

  // When the N check fails, fall back to an empty base sequence. The user
  // then sees the static_assert instead of the compiler trying to
  // materialize ~2^64 indices.
  using type = typename detail::shift_index_sequence<
      Start,
      std::make_index_sequence<(static_cast<int>(N) >= 0 ? N : 0)>,
      Is...>::type;
};

template <size_t Start, size_t N>
using make_offset_index_sequence =
    typename make_offset_index_sequence_impl<Start, N>::type;
104*da0073e9SAndroid Build Coastguard Worker 
/**
 * Use tuple_elements to extract a position-indexed subset of elements
 * from the argument tuple into a result tuple.
 *
 * Example:
 *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *  std::tuple<int, double> result =
 *      tuple_elements(t, std::index_sequence<0, 2>());
 *
 * The result's element types are exactly std::tuple_element_t<I, Tuple> for
 * each requested index I. `t` is a by-value copy owned by this call, so an
 * element requested exactly once is moved out of it rather than copied again.
 * This also makes move-only elements work when the caller passes an rvalue.
 * An index requested more than once is copied for every occurrence, so no
 * result slot ever sees a moved-from value.
 */
namespace detail {
// Element I of `t` for tuple_elements: an rvalue when I occurs exactly once
// in Is..., otherwise an lvalue (repeated indices are copied, never moved
// twice).
template <size_t I, size_t... Is, class Tuple>
constexpr decltype(auto) tuple_elements_get(Tuple& t) {
  if constexpr (((I == Is) + ... + 0) == 1) {
    return std::get<I>(std::move(t));
  } else {
    return std::get<I>(t);
  }
}
} // namespace detail

template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
  return std::tuple<std::tuple_element_t<Is, Tuple>...>(
      detail::tuple_elements_get<Is, Is...>(t)...);
}
118*da0073e9SAndroid Build Coastguard Worker 
119*da0073e9SAndroid Build Coastguard Worker /**
120*da0073e9SAndroid Build Coastguard Worker  * Use tuple_take to extract the first or last n elements from the argument
121*da0073e9SAndroid Build Coastguard Worker  * tuple into a result tuple.
122*da0073e9SAndroid Build Coastguard Worker  *
123*da0073e9SAndroid Build Coastguard Worker  * Example:
124*da0073e9SAndroid Build Coastguard Worker  *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
125*da0073e9SAndroid Build Coastguard Worker  *  std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
126*da0073e9SAndroid Build Coastguard Worker  *  std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
127*da0073e9SAndroid Build Coastguard Worker  */
128*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N, class Enable = void>
129*da0073e9SAndroid Build Coastguard Worker struct TupleTake {};
130*da0073e9SAndroid Build Coastguard Worker 
131*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N>
132*da0073e9SAndroid Build Coastguard Worker struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> {
133*da0073e9SAndroid Build Coastguard Worker   static auto call(Tuple t) {
134*da0073e9SAndroid Build Coastguard Worker     constexpr size_t size = std::tuple_size<Tuple>();
135*da0073e9SAndroid Build Coastguard Worker     static_assert(N <= size, "tuple_take: N > size");
136*da0073e9SAndroid Build Coastguard Worker     return tuple_elements(t, std::make_index_sequence<N>{});
137*da0073e9SAndroid Build Coastguard Worker   }
138*da0073e9SAndroid Build Coastguard Worker };
139*da0073e9SAndroid Build Coastguard Worker 
140*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N>
141*da0073e9SAndroid Build Coastguard Worker     struct TupleTake < Tuple,
142*da0073e9SAndroid Build Coastguard Worker     N, std::enable_if_t<N<0, void>> {
143*da0073e9SAndroid Build Coastguard Worker   static auto call(Tuple t) {
144*da0073e9SAndroid Build Coastguard Worker     constexpr size_t size = std::tuple_size<Tuple>();
145*da0073e9SAndroid Build Coastguard Worker     static_assert(-N <= size, "tuple_take: -N > size");
146*da0073e9SAndroid Build Coastguard Worker     return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
147*da0073e9SAndroid Build Coastguard Worker   }
148*da0073e9SAndroid Build Coastguard Worker };
149*da0073e9SAndroid Build Coastguard Worker 
150*da0073e9SAndroid Build Coastguard Worker template <class Tuple, int N>
151*da0073e9SAndroid Build Coastguard Worker auto tuple_take(Tuple t) {
152*da0073e9SAndroid Build Coastguard Worker   return TupleTake<Tuple, N>::call(t);
153*da0073e9SAndroid Build Coastguard Worker }
154*da0073e9SAndroid Build Coastguard Worker 
155*da0073e9SAndroid Build Coastguard Worker /**
156*da0073e9SAndroid Build Coastguard Worker  * Use tuple_slice to extract a contiguous subtuple from the argument.
157*da0073e9SAndroid Build Coastguard Worker  *
158*da0073e9SAndroid Build Coastguard Worker  * Example:
159*da0073e9SAndroid Build Coastguard Worker  *  std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
160*da0073e9SAndroid Build Coastguard Worker  * "HEY", 2.0, false); std::tuple<int, const char*> middle_two =
161*da0073e9SAndroid Build Coastguard Worker  * tuple_slice<decltype(t), 1, 2>(t);
162*da0073e9SAndroid Build Coastguard Worker  */
163*da0073e9SAndroid Build Coastguard Worker template <class Tuple, size_t Start, size_t N>
164*da0073e9SAndroid Build Coastguard Worker constexpr auto tuple_slice(Tuple t) {
165*da0073e9SAndroid Build Coastguard Worker   constexpr size_t size = std::tuple_size<Tuple>();
166*da0073e9SAndroid Build Coastguard Worker   static_assert(Start + N <= size, "tuple_slice: Start + N > size");
167*da0073e9SAndroid Build Coastguard Worker   return tuple_elements(t, make_offset_index_sequence<Start, N>{});
168*da0073e9SAndroid Build Coastguard Worker }
169*da0073e9SAndroid Build Coastguard Worker 
/**
 * Use tuple_map to run a mapping function over a tuple to get a new tuple.
 *
 * The input tuple must be an rvalue; each element is forwarded to the mapper
 * (as an rvalue for value elements, with its own reference kind for reference
 * elements). The mapper is invoked on the elements strictly left to right. The
 * i-th result element has exactly the type the mapper returns for the i-th
 * input (references are kept, not decayed).
 *
 * Example 1:
 *   auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5),
 *                           [](int32_t a) -> int16_t { return a + 1; });
 *   // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
 *
 * Example 2:
 *   struct Mapper {
 *     std::string operator()(int32_t a) const {
 *       return std::to_string(a);
 *     }
 *     int64_t operator()(const std::string& a) const {
 *       return atoi(a.c_str());
 *     }
 *   };
 *   auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
 *                           Mapper());
 *   // result == std::tuple<std::string, int64_t>("3", 4)
 *
 * Example 3:
 *   struct A final {
 *     int32_t func() {
 *       return 5;
 *     }
 *   };
 *   struct B final {
 *     std::string func() {
 *       return "5";
 *     }
 *   };
 *   auto result =
 *       tuple_map(std::make_tuple(A(), B()), [](auto a) { return a.func(); });
 *   // result == std::tuple<int32_t, std::string>(5, "5");
 */
namespace detail {
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
    // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
    std::tuple<Args...>&& tuple,
    const Mapper& mapper,
    std::index_sequence<Indices...>) {
  // `tuple` is an lvalue here; std::forward<Args> restores each element's
  // value category before handing it to the mapper.
  using Result = std::tuple<decltype(mapper(
      std::forward<Args>(std::get<Indices>(tuple))))...>;
  // Braced (list) initialization sequences its initializer clauses left to
  // right. A parenthesized constructor call would leave the order of the
  // mapper invocations unspecified, which matters for side-effecting mappers.
  return Result{mapper(std::forward<Args>(std::get<Indices>(tuple)))...};
}
} // namespace detail

template <class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
  return detail::tuple_map(
      std::move(tuple), mapper, std::index_sequence_for<Args...>());
}
223*da0073e9SAndroid Build Coastguard Worker 
224*da0073e9SAndroid Build Coastguard Worker } // namespace c10::guts
225