1 #pragma once 2 3 #include <ATen/core/boxing/OperatorKernel.h> 4 #include <ATen/core/ivalue.h> 5 #include <ATen/core/stack.h> 6 #include <c10/util/TypeList.h> 7 #include <ATen/core/IListRef.h> 8 #include <c10/util/intrusive_ptr.h> 9 #include <c10/util/Metaprogramming.h> 10 11 #include <utility> 12 13 namespace c10 { 14 15 using Stack = torch::jit::Stack; // TODO Instead of this, move torch::jit::Stack to the c10 namespace. 16 class OperatorHandle; 17 18 /* 19 * [Note: Argument forwarding in the dispatcher] 20 * 21 * The dispatcher uses a somewhat unusual way to forward arguments through several layers of 22 * wrapper functions. This can be confusing because an experienced C++ programmer would look at this 23 * and think "oh this is supposed to be forwarding a universal reference but the && is missing. This is a bug.". 24 * It is not a bug. The common way in C++ to forward arguments is to use universal references: 25 * 26 * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); } 27 * 28 * but that relies on inferring the correct reference type (i.e. value vs & vs &&) from the argument. 29 * In our case, we cannot rely on the argument as supplied by the caller, because that could infer a 30 * different reference type than was used in the kernel function. The correct reference type 31 * is dictated by the kernel signature and must be identical since we cast function pointers 32 * through void* pointers and mismatches would be UB. So we need a forwarding pattern that determines 33 * the reference type to use by looking at the explicitly supplied operator signature, not by looking at 34 * the argument we're calling it with. 35 * 36 * What does std::forward do, exactly? 37 * ------------------------------------ 38 * std::forward<T>(t) is a way to cast t to the reference type supplied in T. 39 * Let's assume decay_t<T> == U and T is either U or some reference of U. 40 * - std::forward<T&>(t) will return U&, no matter what kind of reference t is. 41 * - std::forward<T&&>(t) will return U&&, no matter what kind of reference t is. 42 * - std::forward<T>(t) will return U&& (not U!), no matter what kind of reference t is. 43 * 44 * For universal references, that means that in the following function 45 * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); } 46 * 47 * - when called with arg being a rvalue reference or non-reference value, T gets inferred to be 48 * a non-reference U, and std::forward<T>(t) will return U&&, correctly moving the argument. 49 * - when called with arg behind a lvalue reference, T gets inferred to be U& because that's the only 50 * way to match the signature (in C++, a type that is (T&)&& will collapse to T&). 51 * That means std::forward<T>(t) will return U& and the value will not be moved but passed on as 52 * a lvalue reference. 53 * 54 * How do we use that? 55 * ------------------------------------ 56 * But std::forward can also be used outside of the common "universal forwarding" pattern to change 57 * reference types. So instead of following the common C++ pattern, we notice what 58 * std::forward<T>() actually does, and that is it takes a value and changes its reference to the 59 * type of reference passed in as T. If we don't infer T but explicitly specify it, we can use this 60 * to forward based on an explicitly specified reference type instead of the inferred argument type. 61 * 62 * This is why many of the dispatcher functions look like 63 * > template<class T> func(T t) { func2<T>(std::forward<T>(t)); } 64 * instead of the common 65 * > template<class T> func(T&& t) { func2(std::forward<T>(t)); } 66 * 67 * and are expected to be called by explicitly specifying the template parameters in a way that matches 68 * the expected operator signature at each call site. 69 */ 70 71 namespace impl { 72 // supported_primitive_arg_types defines which primitive types we allow in 73 // kernel functions as arguments or returns. 74 // Additionally, we support lists, dicts and optionals containing these types. 75 using supported_primitive_arg_types = guts::typelist::typelist< 76 int64_t, 77 double, 78 bool, 79 c10::string_view, 80 at::Tensor, 81 at::Scalar, 82 c10::QScheme, 83 c10::ScalarType, 84 c10::Device, 85 c10::DeviceIndex, 86 c10::Layout, 87 c10::MemoryFormat, 88 at::Dimname 89 >; 90 91 // We have an unboxed functor in hand that takes C++ arguments, and 92 // we're building a boxed functor wrapper for it that takes IValues. 93 // So "outside" is boxed and "inside" is unboxed. 94 // 95 // So a valid input type is one that our boxed functor wrapper can 96 // unbox from an IValue into a C++ value. 97 // 98 // Whereas a valid output type is one that our wrapper can recieve 99 // as a C++ value from the unboxed functor, and box into an IValue. 100 101 // 102 // assert_is_valid_input_type 103 // checks that T can be unboxed from an IValue into a C++ value. 104 // 105 106 template<class T, bool AllowDeprecatedTypes, class Enable = void> 107 struct assert_is_valid_input_type { assert_is_valid_input_typeassert_is_valid_input_type108 assert_is_valid_input_type() { 109 if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::value) { 110 /* everything is ok, this is a primitive type */ 111 } else { 112 /* otherwise this must be an instance of a valid custom class, since it can only 113 have been created via IValue(x), which ensures this. */ 114 } 115 } 116 }; 117 118 template<class T, bool AllowDeprecatedTypes> 119 struct assert_is_valid_input_type<std::optional<T>, AllowDeprecatedTypes> 120 : assert_is_valid_input_type<T, AllowDeprecatedTypes> {}; 121 122 template <bool AllowDeprecatedTypes, class... Args> 123 struct TypeCheckHelper; 124 125 template <bool AllowDeprecatedTypes> 126 struct TypeCheckHelper<AllowDeprecatedTypes> {}; 127 128 template <bool AllowDeprecatedTypes, class Head, class... Rest> 129 struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...> 130 : TypeCheckHelper<AllowDeprecatedTypes, Rest...> { 131 assert_is_valid_input_type<Head, AllowDeprecatedTypes> check; 132 }; 133 134 template<class... Contained, bool AllowDeprecatedTypes> 135 struct assert_is_valid_input_type<std::tuple<Contained...>, AllowDeprecatedTypes> 136 : TypeCheckHelper<AllowDeprecatedTypes, Contained...> {}; 137 138 template<class Key, class Value, bool AllowDeprecatedTypes> 139 struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes> 140 : assert_is_valid_input_type<Value, AllowDeprecatedTypes> { 141 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value, 142 "You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string."); 143 }; 144 145 template<class Key, class Value, bool AllowDeprecatedTypes> 146 struct assert_is_valid_input_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes> 147 : assert_is_valid_input_type<Value, AllowDeprecatedTypes> { 148 static_assert(AllowDeprecatedTypes, 149 "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead."); 150 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value, 151 "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string."); 152 }; 153 154 template<class T, bool AllowDeprecatedTypes> 155 struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes> 156 : assert_is_valid_input_type<T, AllowDeprecatedTypes> { 157 static_assert(!std::is_same<T, at::Scalar>::value, 158 "You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead."); 159 }; 160 161 template<class T, bool AllowDeprecatedTypes> 162 struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes> 163 : assert_is_valid_input_type<T, AllowDeprecatedTypes> { 164 static_assert(!std::is_same<T, at::Scalar>::value, 165 "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead."); 166 }; 167 168 template<class T, bool AllowDeprecatedTypes> 169 struct assert_is_valid_input_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes> 170 : assert_is_valid_input_type<T, AllowDeprecatedTypes> { 171 static_assert(!std::is_same<T, at::Scalar>::value, 172 "You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead."); 173 }; 174 175 template<class T, size_t N, bool AllowDeprecatedTypes> 176 struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes> 177 : assert_is_valid_input_type<T, AllowDeprecatedTypes> { 178 static_assert(!std::is_same<T, at::Scalar>::value, 179 "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead."); 180 }; 181 182 template<class T, bool AllowDeprecatedTypes> 183 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> { 184 // There is no reason to support float when we have double. Keep the API lean. 185 static_assert(guts::false_t<T>::value, 186 "You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string."); 187 }; 188 template<class T, bool AllowDeprecatedTypes> 189 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> { 190 static_assert(guts::false_t<T>::value, 191 "You tried to register a kernel with an unsupported input type: const char*. Please use c10::string_view instead."); 192 }; 193 template<class T, bool AllowDeprecatedTypes> 194 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> { 195 static_assert(guts::false_t<T>::value, 196 "You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead."); 197 }; 198 template<class T, bool AllowDeprecatedTypes> 199 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> { 200 static_assert(guts::false_t<T>::value, 201 "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string."); 202 }; 203 template<class T, bool AllowDeprecatedTypes> 204 struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const c10::SymInt&, T>::value>> { 205 static_assert(guts::false_t<T>::value, 206 "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead."); 207 }; 208 209 // TODO: it probably would be good to tighten this up quite a bit more with 210 // an explicit list for everything 211 212 // 213 // assert_is_valid_output_type 214 // 215 216 template<class T, bool AllowDeprecatedTypes, class Enable = void> 217 struct assert_is_valid_output_type { 218 assert_is_valid_output_type() { 219 if constexpr(guts::typelist::contains<supported_primitive_arg_types, T>::value) { 220 /* everything is ok, this is a primitive type */ 221 } else { 222 /* otherwise T is verified to be a registered custom class in the IValue 223 constructor, so no benefit in double-checking here */ 224 } 225 } 226 }; 227 228 template<class T, bool AllowDeprecatedTypes> 229 struct assert_is_valid_output_type<std::optional<T>, AllowDeprecatedTypes> 230 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {}; 231 232 template<class T, bool AllowDeprecatedTypes> 233 struct assert_is_valid_output_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes> 234 : assert_is_valid_output_type<T, AllowDeprecatedTypes> {}; 235 236 template<class Key, class Value, bool AllowDeprecatedTypes> 237 struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes> 238 : assert_is_valid_output_type<Value, AllowDeprecatedTypes> { 239 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value, 240 "You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string."); 241 static_assert(!std::is_same<Value, at::Scalar>::value, 242 "You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>."); 243 }; 244 245 template<class Key, class Value, bool AllowDeprecatedTypes> 246 struct assert_is_valid_output_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes> 247 : assert_is_valid_output_type<Value, AllowDeprecatedTypes> { 248 static_assert(AllowDeprecatedTypes, 249 "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead."); 250 static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value, 251 "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string."); 252 static_assert(!std::is_same<Value, at::Scalar>::value, 253 "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>."); 254 }; 255 256 template<class T, bool AllowDeprecatedTypes> 257 struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes> 258 : assert_is_valid_output_type<T, AllowDeprecatedTypes> { 259 static_assert(!std::is_same<T, at::Scalar>::value, 260 "You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead."); 261 }; 262 263 template<class T, bool AllowDeprecatedTypes> 264 struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes> 265 : assert_is_valid_output_type<T, AllowDeprecatedTypes> { 266 static_assert(!std::is_same<T, at::Scalar>::value, 267 "You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead."); 268 // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector<T>. Please use List<T> instead."); 269 }; 270 271 template<class T, size_t N, bool AllowDeprecatedTypes> 272 struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes> 273 : assert_is_valid_output_type<T, AllowDeprecatedTypes> { 274 static_assert(!std::is_same<T, at::Scalar>::value, 275 "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead."); 276 }; 277 278 // The following specialisations of assert_is_valid_output_type are technically not 279 // necessary since we would hit the base case and show an error message 280 // there if they didn't exist, but we can show a better error message 281 // in some common error scenarios. 282 template<class T, bool AllowDeprecatedTypes> 283 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> { 284 // There is no reason to support float when we have double. Keep the API lean. 285 static_assert(guts::false_t<T>::value, 286 "You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string."); 287 }; 288 template<class T, bool AllowDeprecatedTypes> 289 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> { 290 static_assert(guts::false_t<T>::value, 291 "You tried to register a kernel with an unsupported output type: const char*. Please use c10::string_view instead."); 292 }; 293 template<class T, bool AllowDeprecatedTypes> 294 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> { 295 static_assert(guts::false_t<T>::value, 296 "You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead."); 297 }; 298 template<class T, bool AllowDeprecatedTypes> 299 struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> { 300 static_assert(guts::false_t<T>::value, 301 "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string."); 302 }; 303 304 // ivalue_to_arg 305 306 template<class T> 307 struct decay_if_not_tensor final { 308 using type = std::decay_t<T>; 309 }; 310 311 template<> 312 struct decay_if_not_tensor<at::Tensor&> final { 313 using type = at::Tensor&; 314 }; 315 316 template<> 317 struct decay_if_not_tensor<const at::Tensor&> final { 318 using type = const at::Tensor&; 319 }; 320 321 template<class T, bool AllowDeprecatedTypes> 322 struct ivalue_to_arg final { 323 static decltype(auto) call(IValue& v) { 324 assert_is_valid_input_type<T, AllowDeprecatedTypes>(); 325 return std::move(v).to<T>(); 326 } 327 }; 328 329 // The following two specializations take advantage of specialized 330 // `toTensor()` overloads on IValue to avoid copying. 331 template<bool AllowDeprecatedTypes> 332 struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final { 333 // We cannot use the default implementation if they asked for a 334 // `at::Tensor&` because it moves from the IValue, so it can't get 335 // an lvalue reference. 336 static at::Tensor& call(IValue& v) { 337 // Tensor& is valid, don't bother asserting 338 return v.toTensor(); 339 } 340 }; 341 342 template<bool AllowDeprecatedTypes> 343 struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final { 344 // We should not use the default implementation if they asked for 345 // a `const at::Tensor&` because it moves from the IValue and they 346 // didn't ask for that. 347 static const at::Tensor& call(IValue& v) { 348 // const Tensor& is valid, don't bother asserting 349 return v.toTensor(); 350 } 351 }; 352 353 template<bool AllowDeprecatedTypes> 354 struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final { 355 static List<at::Tensor> call(IValue& v) { 356 return v.toTensorList(); 357 } 358 }; 359 360 template<class T, bool AllowDeprecatedTypes> 361 struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final { 362 // If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that 363 // to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>. 364 static std::vector<T> call(IValue& v) { 365 return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v); 366 } 367 }; 368 template<bool AllowDeprecatedTypes> 369 struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final { 370 static std::vector<c10::SymInt> call(IValue& v) { 371 if (v.isIntList()) { 372 std::vector<c10::SymInt> r; 373 auto src = v.toIntList(); 374 std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); }); 375 return r; 376 } else { 377 return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v); 378 } 379 } 380 }; 381 template<bool AllowDeprecatedTypes> 382 struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes> final { 383 static OptionalArray<c10::SymInt> call(IValue& v) { 384 if (v.isIntList()) { 385 std::vector<c10::SymInt> r; 386 auto src = v.toIntList(); 387 std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); }); 388 return OptionalArray<c10::SymInt>(std::move(r)); 389 } else { 390 return std::move(v).to<OptionalArray<c10::SymInt>>(); 391 } 392 } 393 }; 394 template<class T, bool AllowDeprecatedTypes> 395 struct ivalue_to_arg<std::optional<ArrayRef<T>>, AllowDeprecatedTypes> final { 396 // If an argument is std::optional<ArrayRef<T>>, convert the IValue to an std::optional<std::vector<T>> and pass that 397 // to the operator. OptionalArray<T> is basically a std::optional<std::vector<T>> but implicitly convertible 398 // to std::optional<ArrayRef<T>>. 399 static OptionalArray<T> call(IValue& v) { 400 return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v); 401 } 402 }; 403 404 template<class T, bool AllowDeprecatedTypes> 405 struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final { 406 // If an argument is OptionalArrayRef<T>, convert the IValue to an 407 // std::optional<std::vector<T>> and pass that to the operator. OptionalArray<T> 408 // is basically a std::optional<std::vector<T>> but implicitly convertible to 409 // OptionalArrayRef<T> 410 static OptionalArray<T> call(IValue& v) { 411 return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v); 412 } 413 }; 414 415 // return_to_ivalue 416 template<class T, bool AllowDeprecatedTypes, class Enable = void> 417 struct return_to_ivalue final {}; 418 419 template<class T, bool AllowDeprecatedTypes> 420 struct return_to_ivalue<T, AllowDeprecatedTypes, std::enable_if_t<!std::is_same<at::Tensor&, T>::value>> final { 421 static IValue call(T&& v) { 422 assert_is_valid_output_type<T, AllowDeprecatedTypes>(); 423 return c10::ivalue::from(std::move(v)); 424 } 425 static IValue copy(const T& v) { 426 assert_is_valid_output_type<T, AllowDeprecatedTypes>(); 427 return IValue(v); 428 } 429 }; 430 431 // Special case to allow kernels to return `Tensor&`. 432 // TODO Delete this once kernels don't do that anymore 433 template<bool AllowDeprecatedTypes> 434 struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final { 435 static IValue call(at::Tensor& v) { 436 return c10::ivalue::from(v); 437 } 438 static IValue copy(at::Tensor& v) { 439 return IValue(v); 440 } 441 }; 442 443 // wrap_kernel_functor_unboxed_ 444 445 template<class KernelFunctor, class OpSignature> 446 struct wrap_kernel_functor_unboxed_ final {}; 447 448 // This specialization is for kernels with a first argument that is NOT of type DispatchKeySet 449 // This includes kernels with 0 arguments. 450 template<class KernelFunctor, class ReturnType, class... ParameterTypes> 451 struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(ParameterTypes...)> final { 452 static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value, 453 "Return type mismatch"); 454 static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value, 455 "Parameter types mismatch"); 456 457 // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use && 458 static ReturnType call(OperatorKernel* functor, DispatchKeySet, ParameterTypes... args) { 459 KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor); 460 // Note [Plumbing Keys Through The Dispatcher 2] 461 // See Note [Plumbing Keys Through The Dispatcher] for the background. 462 // This functor explicitly takes in a dispatchKeySet and drops it on the floor- it does not forward it to the registered kernel. 463 // 464 // This is due to the calling convention within the dispatcher, which expects all registered kernels to have a first argument of type 465 // DispatchKeySet. 466 // This is not the case for pretty much all manually written kernels, however- this functor serves to separate the calling convention 467 // of the dispatcher from the calling convention of manually written kernels. 468 return (*functor_)(std::forward<ParameterTypes>(args)...); 469 } 470 }; 471 472 // This specialization is for kernels with a first argument of type DispatchKeySet 473 template<class KernelFunctor, class ReturnType, class... ParameterTypes> 474 struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(DispatchKeySet, ParameterTypes...)> final { 475 static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value, 476 "Return type mismatch"); 477 static_assert(std::is_same<guts::typelist::typelist<DispatchKeySet, ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value, 478 "Parameter types mismatch"); 479 480 // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use && 481 static ReturnType call(OperatorKernel* functor, DispatchKeySet dispatchKeySet, ParameterTypes... args) { 482 KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor); 483 // We're explicitly taking in a dispatchKeySet and forwarding it to the registered kernel. 484 // See Note [Plumbing Keys Through The Dispatcher 2] for details. 485 return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...); 486 } 487 }; 488 489 template<class KernelFunctor> 490 using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<KernelFunctor, typename guts::infer_function_traits_t<KernelFunctor>::func_type>; 491 492 // call_functor_with_args_from_stack 493 494 template<class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices, typename... ArgTypes> 495 std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type> 496 call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) { 497 (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning. 498 499 // We're explicitly filtering out DispatchKeySet from the argument list. 500 // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher. 501 // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack. 502 // See Note [Plumbing Keys Through The Dispatcher] for the background. 503 return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet, 504 ivalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type, AllowDeprecatedTypes>::call( 505 torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)) 506 )...); 507 } 508 509 template<class Functor, bool AllowDeprecatedTypes> 510 std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type> 511 call_functor_with_args_from_stack(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack) { 512 // We're explicitly filtering out DispatchKeySet from the argument list. 513 // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher. 514 // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack. 515 // See Note [Plumbing Keys Through The Dispatcher] for the background. 516 using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<Functor>::parameter_types; 517 constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value; 518 return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack, std::make_index_sequence<num_ivalue_args>(), static_cast<ArgTypes*>(nullptr)); 519 } 520 521 // push_outputs 522 523 template<class OutputType, bool AllowDeprecatedTypes> 524 struct push_outputs final { 525 // Contrary to [Note: Argument forwarding in the dispatcher], we use OutputType&& here 526 // to avoid one extra call to the move constructor in this case. This is still not a 527 // universal reference though because OutputType is an explicitly specified class 528 // template parameter. 529 static void call(OutputType&& output, Stack* stack) { 530 torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(std::forward<OutputType>(output))); 531 } 532 static void copy(const OutputType& output, Stack* stack) { 533 torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output)); 534 } 535 }; 536 template<class... OutputTypes, bool AllowDeprecatedTypes> 537 struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final { 538 static void call(std::tuple<OutputTypes...>&& output, Stack* stack) { 539 call_(std::move(output), stack, std::make_index_sequence<sizeof...(OutputTypes)>()); 540 } 541 static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) { 542 copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>()); 543 } 544 545 private: 546 template<size_t... indices> 547 static void call_(std::tuple<OutputTypes...>&& output, Stack* stack, std::index_sequence<indices...>) { 548 torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(std::forward<OutputTypes>(std::get<indices>(output)))...); 549 } 550 template<size_t... indices> 551 static void copy_(const std::tuple<OutputTypes...>& output, Stack* stack, std::index_sequence<indices...>) { 552 torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(std::get<indices>(output))...); 553 } 554 }; 555 template<bool AllowDeprecatedTypes> 556 struct push_outputs<void, AllowDeprecatedTypes> final { 557 static void call(int /*dummy*/, Stack* /*stack*/) { 558 } 559 static void copy(int /*dummy*/, Stack* /*stack*/) { 560 } 561 }; 562 563 // make_boxed_from_unboxed_functor 564 565 template<class KernelFunctor, bool AllowDeprecatedTypes> 566 struct make_boxed_from_unboxed_functor final { 567 static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, 568 "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); 569 570 static void call(OperatorKernel* functor, const OperatorHandle&, DispatchKeySet dispatchKeySet, Stack* stack) { 571 using ReturnType = typename guts::infer_function_traits_t<KernelFunctor>::return_type; 572 // We're explicitly filtering out DispatchKeySet from the argument list. 573 // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher. 574 // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack. 575 // See Note [Plumbing Keys Through The Dispatcher] for the background. 576 using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::parameter_types; 577 constexpr bool has_outputs = !std::is_same<void, ReturnType>::value; 578 constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value; 579 if constexpr (has_outputs) { 580 // Decay ReturnType to ReturnType_ so that if a reference gets returned, we actually store it by value 581 // and don't get a dangling reference. This is only required because some kernels still return `Tensor&`. 582 // [Note: VC++ and 'std': ambiguous symbol] 583 using ReturnType_ = ::std::decay_t<ReturnType>; 584 ReturnType_ output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack); 585 torch::jit::drop(*stack, num_inputs); 586 // See note [ VC++ and 'std': ambiguous symbol] 587 push_outputs<ReturnType_, AllowDeprecatedTypes>::call(::std::move(output), stack); 588 } else { 589 call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack); 590 torch::jit::drop(*stack, num_inputs); 591 } 592 } 593 }; 594 } // namespace impl 595 596 } // namespace c10 597 598 namespace torch { 599 using OperatorKernel = c10::OperatorKernel; 600 } 601