xref: /aosp_15_r20/external/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 //===----------------------------------------------------------------------===//
10 /// \file extension/kernel_util/make_boxed_from_unboxed_functor.h
11 /// Defines a template that can be used to create a boxed version of an unboxed
12 /// functor.
13 /// Example usage:
14 /// ```
15 /// Tensor&
16 /// my_op(KernelRuntimeContext& ctx, const Tensor& self, const Tensor& other,
17 ///       Tensor& out)
18 /// {
19 ///   // ...
20 ///   return out;
21 /// }
22 ///
23 /// Kernel my_kernel = make_boxed_kernel("my_ns::my_op",
24 ///   EXECUTORCH_FN(my_op));
25 /// static auto res = register_kernels({my_kernel});
26 /// ```
27 /// Or simply:
28 /// ```
29 /// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
30 /// ```
31 ///
32 /// The trick here is to convert each EValue to inferred argument type. This
33 /// uses a lot of C++17 features.
34 //===----------------------------------------------------------------------===//
35 
36 #pragma once
37 #if __cplusplus < 201703L
38 #error "This header requires C++17"
39 #endif
40 
41 #include <executorch/extension/kernel_util/meta_programming.h>
42 #include <executorch/extension/kernel_util/type_list.h>
43 #include <executorch/runtime/core/evalue.h>
44 #include <executorch/runtime/core/exec_aten/exec_aten.h>
45 #include <executorch/runtime/kernel/operator_registry.h>
46 #include <cstdlib>
47 #include <memory>
48 #include <type_traits>
49 #include <typeinfo>
50 
// Forward declaration of the runtime context type that the signatures below
// take by reference.
namespace executorch::runtime {
class KernelRuntimeContext;
} // namespace executorch::runtime
56 
57 namespace executorch {
58 namespace extension {
59 
60 // This extension has a lot of generic internal names like "size"; use a unique
61 // internal namespace to avoid conflicts with other extensions.
62 namespace kernel_util_internal {
63 
64 template <class T>
65 struct decay_if_not_tensor final {
66   using type = std::decay_t<T>;
67 };
68 template <>
69 struct decay_if_not_tensor<executorch::aten::Tensor&> final {
70   using type = executorch::aten::Tensor&;
71 };
72 template <>
73 struct decay_if_not_tensor<const executorch::aten::Tensor&> final {
74   using type = const executorch::aten::Tensor&;
75 };
76 
77 template <class T>
78 struct evalue_to_arg final {
79   static T call(executorch::runtime::EValue& v) {
80     return std::move(v).to<T>();
81   }
82 };
83 
84 template <>
85 struct evalue_to_arg<executorch::aten::Tensor&> final {
86   static executorch::aten::Tensor& call(executorch::runtime::EValue& v) {
87     return v.toTensor();
88   }
89 };
90 
91 template <>
92 struct evalue_to_arg<const executorch::aten::Tensor&> final {
93   static const executorch::aten::Tensor& call(executorch::runtime::EValue& v) {
94     return v.toTensor();
95   }
96 };
97 
98 template <class T>
99 struct evalue_to_arg<executorch::aten::optional<T>> final {
100   static executorch::aten::optional<T> call(executorch::runtime::EValue& v) {
101     return v.toOptional<T>();
102   }
103 };
104 
105 template <class T>
106 struct evalue_to_arg<executorch::aten::ArrayRef<executorch::aten::optional<T>>>
107     final {
108   static executorch::aten::ArrayRef<executorch::aten::optional<T>> call(
109       executorch::runtime::EValue& v) {
110     return v.toListOptionalTensor();
111   }
112 };
113 
114 template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes>
115 void call_functor_with_args_from_stack(
116     ::executorch::runtime::KernelRuntimeContext& ctx,
117     executorch::runtime::EValue** stack,
118     std::index_sequence<evalue_arg_indices...>,
119     typelist<ArgTypes...>*) {
120   (*Functor::func_ptr())(
121       ctx,
122       evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
123           *stack[evalue_arg_indices])...);
124 }
125 
126 } // namespace kernel_util_internal
127 
128 /**
129  * WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that
130  * takes EValues as input and returns void. The wrapped functor will unbox all
131  * inputs and forward them to unboxed kernel.
132  */
133 template <class FuncType>
134 struct WrapUnboxedIntoFunctor {
135   static_assert(
136       kernel_util_internal::is_compile_time_function_pointer<FuncType>::value,
137       "Can't handle function other than EXECUTORCH_FN");
138   using TrueType = typename FuncType::FuncType;
139   using ReturnType = typename kernel_util_internal::infer_function_traits_t<
140       TrueType>::return_type;
141   using ArgsType = typename kernel_util_internal::infer_function_traits_t<
142       TrueType>::parameter_types;
143   // check if the first argument is KernelRuntimeContext, if so, remove it
144   static constexpr bool first_arg_is_context = std::is_same<
145       ::executorch::runtime::KernelRuntimeContext,
146       std::remove_reference_t<
147           kernel_util_internal::head_with_default_t<void, ArgsType>>>::value;
148   using ContextRemovedArgsType = std::conditional_t<
149       first_arg_is_context,
150       kernel_util_internal::drop_if_nonempty_t<ArgsType, 1>,
151       ArgsType>;
152 
153   static void call(
154       ::executorch::runtime::KernelRuntimeContext& ctx,
155       executorch::runtime::EValue** stack) {
156     constexpr size_t num_inputs =
157         kernel_util_internal::size<ContextRemovedArgsType>::value;
158     return kernel_util_internal::call_functor_with_args_from_stack<FuncType>(
159         ctx,
160         stack,
161         std::make_index_sequence<num_inputs>(),
162         static_cast<ContextRemovedArgsType*>(nullptr));
163   }
164 };
165 
166 template <typename FuncType>
167 static executorch::runtime::Kernel make_boxed_kernel(
168     const char* name,
169     FuncType) {
170   return executorch::runtime::Kernel(
171       name, WrapUnboxedIntoFunctor<FuncType>::call);
172 }
173 
174 } // namespace extension
175 } // namespace executorch
176 
// Inspired from C10_CONCATENATE
// Two-level token pasting: the extra level lets macro arguments (such as
// ET_UID) expand before they are joined.
#define ET_CONCATENATE_IMPL(s1, s2) s1##s2
#define ET_CONCATENATE(s1, s2) ET_CONCATENATE_IMPL(s1, s2)

// Suffix used to make generated identifiers unique. __COUNTER__ yields a new
// value on every expansion, so two registrations for the same namespace cannot
// collide even when they share a line number within one translation unit.
// Compilers without __COUNTER__ fall back to __LINE__.
#ifdef __COUNTER__
#define ET_UID __COUNTER__
#else
#define ET_UID __LINE__
#endif

// EXECUTORCH_LIBRARY(ns, op_name, func): registers `func` as the kernel named
// "<ns>::<op_name>". `ns` is a bare identifier (it is stringized and also
// pasted into the generated variable name); `op_name` is a string literal.
// Registration runs when the generated static variable is initialized.
#define EXECUTORCH_LIBRARY(ns, op_name, func) \
  _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, ET_UID)

// Implementation detail of EXECUTORCH_LIBRARY; `uid` arrives already expanded.
#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid) \
  static auto ET_CONCATENATE(res_##ns##_, uid) =         \
      ::executorch::runtime::register_kernel(            \
          ::executorch::extension::make_boxed_kernel(    \
              #ns "::" op_name, EXECUTORCH_FN(func)))
190 
191 namespace torch {
192 namespace executor {
193 // TODO(T197294990): Remove these deprecated aliases once all users have moved
194 // to the new `::executorch` namespaces.
195 using ::executorch::extension::make_boxed_kernel;
196 using ::executorch::extension::WrapUnboxedIntoFunctor;
197 } // namespace executor
198 } // namespace torch
199