xref: /aosp_15_r20/external/pytorch/c10/util/Metaprogramming.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/util/TypeList.h>

#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>
5 
6 namespace c10::guts {
7 
8 /**
9  * Access information about result type or arguments from a function type.
10  * Example:
11  * using A = function_traits<int (float, double)>::return_type // A == int
12  * using A = function_traits<int (float, double)>::parameter_types::tuple_type
13  * // A == tuple<float, double>
14  */
15 template <class Func>
16 struct function_traits {
17   static_assert(
18       !std::is_same_v<Func, Func>,
19       "In function_traits<Func>, Func must be a plain function type.");
20 };
21 template <class Result, class... Args>
22 struct function_traits<Result(Args...)> {
23   using func_type = Result(Args...);
24   using return_type = Result;
25   using parameter_types = typelist::typelist<Args...>;
26   static constexpr auto number_of_parameters = sizeof...(Args);
27 };
28 
29 /**
30  * infer_function_traits: creates a `function_traits` type for a simple
31  * function (pointer) or functor (lambda/struct). Currently does not support
32  * class methods.
33  */
34 
35 template <typename Functor>
36 struct infer_function_traits {
37   using type = function_traits<
38       c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
39 };
40 
41 template <typename Result, typename... Args>
42 struct infer_function_traits<Result (*)(Args...)> {
43   using type = function_traits<Result(Args...)>;
44 };
45 
46 template <typename Result, typename... Args>
47 struct infer_function_traits<Result(Args...)> {
48   using type = function_traits<Result(Args...)>;
49 };
50 
51 template <typename T>
52 using infer_function_traits_t = typename infer_function_traits<T>::type;
53 
54 /**
55  * make_function_traits: creates a `function_traits` type given a Return type
56  * and a typelist of Argument types
57  *
58  * Example:
59  * bool f(int, int);
60  *
61  * infer_function_traits_t<f> == make_function_traits_t<bool,
62  * typelist::typelist<int, int>>
63  */
64 template <typename Result, typename ArgList>
65 struct make_function_traits {
66   static_assert(
67       false_t<ArgList>::value,
68       "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
69 };
70 
71 template <typename Result, typename... Args>
72 struct make_function_traits<Result, typelist::typelist<Args...>> {
73   using type = function_traits<Result(Args...)>;
74 };
75 
76 template <typename Result, typename ArgList>
77 using make_function_traits_t =
78     typename make_function_traits<Result, ArgList>::type;
79 
/**
 * make_offset_index_sequence<Start, N>
 * Like make_index_sequence<N>, but starting from Start instead of 0.
 *
 * Example:
 *  make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
 *
 * make_offset_index_sequence_impl<Start, N, Is...>::type is
 * std::index_sequence<Start, ..., Start + N - 1, Is...>, i.e. any trailing
 * Is... are appended unchanged after the generated range.
 *
 * The range is produced by shifting std::make_index_sequence<N> instead of
 * recursing once per element, so the template instantiation depth does not
 * grow with N.
 */
template <size_t Start, size_t N, size_t... Is>
struct make_offset_index_sequence_impl {
  // A negative value converted to size_t becomes a huge number; casting back
  // to int lets such arguments be rejected with a readable message. These
  // checks come first so they are reported before the sequence is generated.
  static_assert(
      static_cast<int>(Start) >= 0,
      "make_offset_index_sequence: Start < 0");
  static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");

 private:
  // Declaration only (used inside decltype): maps 0..N-1 to
  // Start..Start+N-1 and appends Is...
  template <size_t... Offsets>
  static std::index_sequence<(Start + Offsets)..., Is...> shift(
      std::index_sequence<Offsets...>);

 public:
  using type = decltype(shift(std::make_index_sequence<N>{}));
};

// N == 0: nothing to generate. Kept as its own specialization so an empty
// range is accepted for any Start, without running the range checks above.
template <size_t Start, size_t... Is>
struct make_offset_index_sequence_impl<Start, 0, Is...> {
  using type = std::index_sequence<Is...>;
};

template <size_t Start, size_t N>
using make_offset_index_sequence =
    typename make_offset_index_sequence_impl<Start, N>::type;
104 
/**
 * tuple_elements(t, std::index_sequence<Is...>) builds a new tuple from the
 * elements of `t` at positions Is..., in the order given. The element types
 * of the result are the corresponding element types of Tuple.
 *
 * Example:
 *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *  std::tuple<int, double> result =
 *      tuple_elements(t, std::index_sequence<0, 2>());
 */
template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
  using Selected = std::tuple<typename std::tuple_element<Is, Tuple>::type...>;
  // Elements are copied out of `t` (not moved), so the same index may appear
  // more than once in Is...
  return Selected(std::get<Is>(t)...);
}
118 
119 /**
120  * Use tuple_take to extract the first or last n elements from the argument
121  * tuple into a result tuple.
122  *
123  * Example:
124  *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
125  *  std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
126  *  std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
127  */
128 template <class Tuple, int N, class Enable = void>
129 struct TupleTake {};
130 
131 template <class Tuple, int N>
132 struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> {
133   static auto call(Tuple t) {
134     constexpr size_t size = std::tuple_size<Tuple>();
135     static_assert(N <= size, "tuple_take: N > size");
136     return tuple_elements(t, std::make_index_sequence<N>{});
137   }
138 };
139 
140 template <class Tuple, int N>
141     struct TupleTake < Tuple,
142     N, std::enable_if_t<N<0, void>> {
143   static auto call(Tuple t) {
144     constexpr size_t size = std::tuple_size<Tuple>();
145     static_assert(-N <= size, "tuple_take: -N > size");
146     return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
147   }
148 };
149 
150 template <class Tuple, int N>
151 auto tuple_take(Tuple t) {
152   return TupleTake<Tuple, N>::call(t);
153 }
154 
155 /**
156  * Use tuple_slice to extract a contiguous subtuple from the argument.
157  *
158  * Example:
159  *  std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
160  * "HEY", 2.0, false); std::tuple<int, const char*> middle_two =
161  * tuple_slice<decltype(t), 1, 2>(t);
162  */
163 template <class Tuple, size_t Start, size_t N>
164 constexpr auto tuple_slice(Tuple t) {
165   constexpr size_t size = std::tuple_size<Tuple>();
166   static_assert(Start + N <= size, "tuple_slice: Start + N > size");
167   return tuple_elements(t, make_offset_index_sequence<Start, N>{});
168 }
169 
170 /**
171  * Use tuple_map to run a mapping function over a tuple to get a new tuple.
172  *
 * Example 1:
 *   auto result = tuple_map(
 *       std::tuple<int32_t, int32_t, int32_t>(3, 4, 5),
 *       [](int32_t a) -> int16_t { return a + 1; });
 *   // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
177  *
178  * Example 2:
179  *   struct Mapper {
180  *     std::string operator()(int32_t a) const {
181  *       return std::to_string(a);
182  *     }
183  *     int64_t operator()(const std::string& a) const {
184  *        return atoi(a.c_str());
185  *     }
186  *   };
 *   auto result =
 *       tuple_map(std::tuple<int32_t, std::string>(3, "4"), Mapper());
 *   // result == std::tuple<std::string, int64_t>("3", 4)
190  *
191  * Example 3:
192  *   struct A final {
193  *    int32_t func() {
194  *      return 5;
195  *    }
196  *  };
197  *  struct B final {
198  *    std::string func() {
199  *      return "5";
200  *    }
201  *  };
 *  auto result = tuple_map(
 *      std::make_tuple(A(), B()), [](auto a) { return a.func(); });
 *  // result == std::tuple<int32_t, std::string>(5, "5");
205  */
namespace detail {
// Implementation of tuple_map. Indices... must name the positions of
// Args... in order (Args and Indices are expanded together). Calls
// mapper on each element, forwarded as Args&&, and collects the results in a
// tuple whose element types are exactly what those calls return.
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
    // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
    std::tuple<Args...>&& tuple,
    const Mapper& mapper,
    std::index_sequence<Indices...>) {
  using Mapped = std::tuple<decltype(mapper(
      std::forward<Args>(std::get<Indices>(tuple))))...>;
  return Mapped(mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
}
} // namespace detail
217 
218 template <class Mapper, class... Args>
219 auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
220   return detail::tuple_map(
221       std::move(tuple), mapper, std::index_sequence_for<Args...>());
222 }
223 
224 } // namespace c10::guts
225