/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//===----------------------------------------------------------------------===//
/// \file extension/kernel_util/make_boxed_from_unboxed_functor.h
/// Defines a template that can be used to create a boxed version of an unboxed
/// functor.
/// Example usage:
/// ```
/// Tensor&
/// my_op(KernelRuntimeContext& ctx, const Tensor& self, const Tensor& other,
///       Tensor& out)
/// {
///   // ...
///   return out;
/// }
///
/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op",
///   EXECUTORCH_FN(my_op));
/// static auto res = register_kernels({my_kernel});
/// ```
/// Or simply:
/// ```
/// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
/// ```
///
/// The trick here is to convert each EValue to the inferred argument type.
/// This uses a lot of C++17 features.
34 //===----------------------------------------------------------------------===// 35 36 #pragma once 37 #if __cplusplus < 201703L 38 #error "This header requires C++17" 39 #endif 40 41 #include <executorch/extension/kernel_util/meta_programming.h> 42 #include <executorch/extension/kernel_util/type_list.h> 43 #include <executorch/runtime/core/evalue.h> 44 #include <executorch/runtime/core/exec_aten/exec_aten.h> 45 #include <executorch/runtime/kernel/operator_registry.h> 46 #include <cstdlib> 47 #include <memory> 48 #include <type_traits> 49 #include <typeinfo> 50 51 namespace executorch { 52 namespace runtime { 53 class KernelRuntimeContext; // Forward declaration 54 } // namespace runtime 55 } // namespace executorch 56 57 namespace executorch { 58 namespace extension { 59 60 // This extension has a lot of generic internal names like "size"; use a unique 61 // internal namespace to avoid conflicts with other extensions. 62 namespace kernel_util_internal { 63 64 template <class T> 65 struct decay_if_not_tensor final { 66 using type = std::decay_t<T>; 67 }; 68 template <> 69 struct decay_if_not_tensor<executorch::aten::Tensor&> final { 70 using type = executorch::aten::Tensor&; 71 }; 72 template <> 73 struct decay_if_not_tensor<const executorch::aten::Tensor&> final { 74 using type = const executorch::aten::Tensor&; 75 }; 76 77 template <class T> 78 struct evalue_to_arg final { 79 static T call(executorch::runtime::EValue& v) { 80 return std::move(v).to<T>(); 81 } 82 }; 83 84 template <> 85 struct evalue_to_arg<executorch::aten::Tensor&> final { 86 static executorch::aten::Tensor& call(executorch::runtime::EValue& v) { 87 return v.toTensor(); 88 } 89 }; 90 91 template <> 92 struct evalue_to_arg<const executorch::aten::Tensor&> final { 93 static const executorch::aten::Tensor& call(executorch::runtime::EValue& v) { 94 return v.toTensor(); 95 } 96 }; 97 98 template <class T> 99 struct evalue_to_arg<executorch::aten::optional<T>> final { 100 static 
executorch::aten::optional<T> call(executorch::runtime::EValue& v) { 101 return v.toOptional<T>(); 102 } 103 }; 104 105 template <class T> 106 struct evalue_to_arg<executorch::aten::ArrayRef<executorch::aten::optional<T>>> 107 final { 108 static executorch::aten::ArrayRef<executorch::aten::optional<T>> call( 109 executorch::runtime::EValue& v) { 110 return v.toListOptionalTensor(); 111 } 112 }; 113 114 template <class Functor, size_t... evalue_arg_indices, typename... ArgTypes> 115 void call_functor_with_args_from_stack( 116 ::executorch::runtime::KernelRuntimeContext& ctx, 117 executorch::runtime::EValue** stack, 118 std::index_sequence<evalue_arg_indices...>, 119 typelist<ArgTypes...>*) { 120 (*Functor::func_ptr())( 121 ctx, 122 evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call( 123 *stack[evalue_arg_indices])...); 124 } 125 126 } // namespace kernel_util_internal 127 128 /** 129 * WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that 130 * takes EValues as input and returns void. The wrapped functor will unbox all 131 * inputs and forward them to unboxed kernel. 
132 */ 133 template <class FuncType> 134 struct WrapUnboxedIntoFunctor { 135 static_assert( 136 kernel_util_internal::is_compile_time_function_pointer<FuncType>::value, 137 "Can't handle function other than EXECUTORCH_FN"); 138 using TrueType = typename FuncType::FuncType; 139 using ReturnType = typename kernel_util_internal::infer_function_traits_t< 140 TrueType>::return_type; 141 using ArgsType = typename kernel_util_internal::infer_function_traits_t< 142 TrueType>::parameter_types; 143 // check if the first argument is KernelRuntimeContext, if so, remove it 144 static constexpr bool first_arg_is_context = std::is_same< 145 ::executorch::runtime::KernelRuntimeContext, 146 std::remove_reference_t< 147 kernel_util_internal::head_with_default_t<void, ArgsType>>>::value; 148 using ContextRemovedArgsType = std::conditional_t< 149 first_arg_is_context, 150 kernel_util_internal::drop_if_nonempty_t<ArgsType, 1>, 151 ArgsType>; 152 153 static void call( 154 ::executorch::runtime::KernelRuntimeContext& ctx, 155 executorch::runtime::EValue** stack) { 156 constexpr size_t num_inputs = 157 kernel_util_internal::size<ContextRemovedArgsType>::value; 158 return kernel_util_internal::call_functor_with_args_from_stack<FuncType>( 159 ctx, 160 stack, 161 std::make_index_sequence<num_inputs>(), 162 static_cast<ContextRemovedArgsType*>(nullptr)); 163 } 164 }; 165 166 template <typename FuncType> 167 static executorch::runtime::Kernel make_boxed_kernel( 168 const char* name, 169 FuncType) { 170 return executorch::runtime::Kernel( 171 name, WrapUnboxedIntoFunctor<FuncType>::call); 172 } 173 174 } // namespace extension 175 } // namespace executorch 176 177 // Inspired from C10_CONCATENATE 178 #define ET_CONCATENATE_IMPL(s1, s2) s1##s2 179 #define ET_CONCATENATE(s1, s2) ET_CONCATENATE_IMPL(s1, s2) 180 #define ET_UID __LINE__ 181 182 #define EXECUTORCH_LIBRARY(ns, op_name, func) \ 183 _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, ET_UID) 184 185 #define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, 
func, uid) \ 186 static auto ET_CONCATENATE(res_##ns##_, uid) = \ 187 ::executorch::runtime::register_kernel( \ 188 ::executorch::extension::make_boxed_kernel( \ 189 #ns "::" op_name, EXECUTORCH_FN(func))) 190 191 namespace torch { 192 namespace executor { 193 // TODO(T197294990): Remove these deprecated aliases once all users have moved 194 // to the new `::executorch` namespaces. 195 using ::executorch::extension::make_boxed_kernel; 196 using ::executorch::extension::WrapUnboxedIntoFunctor; 197 } // namespace executor 198 } // namespace torch 199