xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/op_registration/infer_schema.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 /**
4  * This file contains functionality to take a C++ function and infer its
5  * c10::FunctionSchema.
6  */
7 
8 #include <ATen/core/function_schema.h>
9 #include <c10/util/Metaprogramming.h>
10 
11 namespace c10 {
12 namespace detail {
13 
14 namespace infer_schema {
15 
16 /// The templated inference code creates `ArgumentDef` instead of `Argument`,
17 /// because that can be constructed at compile time and has a much smaller
18 /// binary size than having calls to `Argument` constructors in the template.
19 /// Creating `Argument` objects from `ArgumentDef` can then be done at
20 /// runtime in a non-templated way.
21 struct ArgumentDef final {
22   using GetTypeFn = TypePtr();
23   GetTypeFn* getTypeFn;
24   GetTypeFn* getFakeTypeFn;
ArgumentDeffinal25   constexpr ArgumentDef(): getTypeFn(nullptr), getFakeTypeFn(nullptr) {}
ArgumentDeffinal26   explicit constexpr ArgumentDef(GetTypeFn *getTypeFn, GetTypeFn *getFakeTypeFn): getTypeFn(getTypeFn), getFakeTypeFn(getFakeTypeFn) {}
27 };
28 
/// Type-level boolean: bool_t<true> derives from std::true_type and
/// bool_t<false> derives from std::false_type (both are std::bool_constant).
template<bool V>
struct bool_t : std::bool_constant<V> {};
33 
/// Rejects common mistakes in the static C++ types `Types` at compile time.
/// All checking is done by static_asserts when the template is instantiated.
/// The function itself always returns 0, which lets callers invoke it inside
/// a constant expression.
template <class... Types>
constexpr int checkStaticTypes() {
  // The messages are deliberately LOUD so users spot them among the
  // template-instantiation noise.

  // Every integral type must be int8_t, int64_t or bool.
  constexpr bool integralsAreSupported =
      ((!std::is_integral_v<Types> || std::is_same_v<Types, int8_t> ||
        std::is_same_v<Types, int64_t> || std::is_same_v<Types, bool>) && ...);
  static_assert(integralsAreSupported, "INVALID TYPE: Only int8_t, int64_t and bool are supported as an integral argument type");

  // float is never accepted; double must be used instead.
  constexpr bool noFloatAnywhere = (!std::is_same_v<Types, float> && ...);
  static_assert(noFloatAnywhere, "INVALID TYPE: float is not supported as an argument type, use double instead");

  return 0;
}
47 
48 template <typename... Ts, size_t... Is>
49 constexpr std::array<ArgumentDef, sizeof...(Ts)> createArgumentVectorFromTypes(std::index_sequence<Is...>) {
50   return (
51     // Check types for common errors
52     checkStaticTypes<Ts...>(),
53 
54     // Create the return value
55     std::array<ArgumentDef, sizeof...(Ts)>{
56       ArgumentDef(&getTypePtrCopy<std::decay_t<Ts>>, &getFakeTypePtrCopy<std::decay_t<Ts>>)...}
57   );
58 }
59 
60 /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
61 /// as template arguments.
62 template<class ParameterTypes> struct createArguments final {};
63 template<class... ParameterTypes>
64 struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
65   static constexpr std::array<ArgumentDef, sizeof...(ParameterTypes)> call() {
66     return createArgumentVectorFromTypes<ParameterTypes...>(
67         std::make_index_sequence<sizeof...(ParameterTypes)>()
68     );
69   }
70 };
71 
72 /// Creates a vector of `ArgumentDef` from a list of C++ types that are specified
73 /// as a tuple (i.e. in the way c10 kernels return values).
74 /// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
75 /// It can be an empty tuple<>, or void for kernels that don't return anything.
76 /// It can be a single type A (i.e. no tuple) for the case where a kernel just
77 /// returns one value.
78 template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
79 
80 template<class... ReturnTypes>
81 struct createReturns<std::tuple<ReturnTypes...>, void> final {
82   static constexpr std::array<ArgumentDef, sizeof...(ReturnTypes)> call() {
83     return createArgumentVectorFromTypes<ReturnTypes...>(
84         std::make_index_sequence<sizeof...(ReturnTypes)>()
85     );
86   }
87 };
88 
89 template<class ReturnType>
90 struct createReturns<ReturnType, std::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
91   static constexpr std::array<ArgumentDef, 1> call() {
92     return createReturns<std::tuple<ReturnType>>::call();
93   }
94 };
95 
96 template<>
97 struct createReturns<void, void> final {
98   static constexpr std::array<ArgumentDef, 0> call() {
99     return createReturns<std::tuple<>>::call();
100   }
101 };
102 
103 template <typename ReturnType>
104 struct createSingleReturn {
105   static constexpr std::array<ArgumentDef, 1> call() {
106     return createArgumentVectorFromTypes<ReturnType>(std::make_index_sequence<1>());
107   }
108 };
109 
110 TORCH_API FunctionSchema make_function_schema(std::string&& name, std::string&& overload_name, c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
111 TORCH_API FunctionSchema make_function_schema(c10::ArrayRef<ArgumentDef> arguments, c10::ArrayRef<ArgumentDef> returns);
112 
113 /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
114 /// function. Flattens std::tuple returns into multiple return types
115 template <typename FunctionTraits>
116 FunctionSchema createFunctionSchemaFromTraitsFlattenedReturns() {
117  using ReturnType = typename FunctionTraits::return_type;
118  using ParameterTypes = typename FunctionTraits::parameter_types;
119 
120  // arguments and returns are computed into a std::array at compile time and embedded into the binary.
121  // The only code executed at runtime here is the one that creates a std::vector
122  // of the arguments/returns from the std::array.
123  constexpr auto arguments = createArguments<ParameterTypes>::call();
124  constexpr auto returns = createReturns<ReturnType>::call();
125 
126  return make_function_schema(arguments, returns);
127 }
128 
129 /// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
130 /// function. Preserves std::tuple returns as a Tuple return type
131 template <typename FunctionTraits>
132 FunctionSchema createFunctionSchemaFromTraitsSingleReturn(std::string&& name, std::string&& overload_name) {
133  using ReturnType = typename FunctionTraits::return_type;
134  using ParameterTypes = typename FunctionTraits::parameter_types;
135 
136  // arguments and returns are computed into a std::array at compile time and embedded into the binary.
137  // The only code executed at runtime here is the one that creates a std::vector
138  // of the arguments/returns from the std::array.
139  constexpr auto arguments = createArguments<ParameterTypes>::call();
140  constexpr auto returns = createSingleReturn<ReturnType>::call();
141 
142  return make_function_schema(std::move(name), std::move(overload_name), arguments, returns);
143 }
144 
145 }
146 }
147 
148 template<class FuncType>
149 FunctionSchema inferFunctionSchemaFlattenedReturns() {
150   return detail::infer_schema::createFunctionSchemaFromTraitsFlattenedReturns<guts::infer_function_traits_t<FuncType>>();
151 }
152 
153 template<class FuncType>
154 FunctionSchema inferFunctionSchemaSingleReturn(std::string&& name, std::string&& overload_name) {
155   return detail::infer_schema::createFunctionSchemaFromTraitsSingleReturn<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
156 }
157 
158 TORCH_API std::optional<std::string> findSchemaDifferences(const FunctionSchema& inferred, const FunctionSchema& specified);
159 
160 }
161