xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/boxing/OperatorKernel.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/core/stack.h>
6 #include <c10/util/TypeList.h>
7 #include <ATen/core/IListRef.h>
8 #include <c10/util/intrusive_ptr.h>
9 #include <c10/util/Metaprogramming.h>
10 
11 #include <utility>
12 
13 namespace c10 {
14 
// Alias that makes the JIT value stack available as c10::Stack.
// TODO Instead of this, move torch::jit::Stack to the c10 namespace.
using Stack = torch::jit::Stack;
// Forward declaration only; this header uses OperatorHandle solely as a
// `const OperatorHandle&` parameter type.
class OperatorHandle;
17 
18 /*
19  * [Note: Argument forwarding in the dispatcher]
20  *
21  * The dispatcher uses a somewhat unusual way to forward arguments through several layers of
22  * wrapper functions. This can be confusing because an experienced C++ programmer would look at this
23  * and think "oh this is supposed to be forwarding a universal reference but the && is missing. This is a bug.".
24  * It is not a bug. The common way in C++ to forward arguments is to use universal references:
25  *
26  * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
27  *
28  * but that relies on inferring the correct reference type (i.e. value vs & vs &&) from the argument.
29  * In our case, we cannot rely on the argument as supplied by the caller, because that could infer a
30  * different reference type than was used in the kernel function. The correct reference type
31  * is dictated by the kernel signature and must be identical since we cast function pointers
32  * through void* pointers and mismatches would be UB. So we need a forwarding pattern that determines
33  * the reference type to use by looking at the explicitly supplied operator signature, not by looking at
34  * the argument we're calling it with.
35  *
36  * What does std::forward do, exactly?
37  * ------------------------------------
38  * std::forward<T>(t) is a way to cast t to the reference type supplied in T.
 * Let U be a non-reference type (i.e. decay_t<T> == U, where T is either U or some reference to U).
 *  - std::forward<U&>(t) will return U&, no matter what kind of reference t is.
 *  - std::forward<U&&>(t) will return U&&, no matter what kind of reference t is.
 *  - std::forward<U>(t) will return U&& (not U!), no matter what kind of reference t is.
43  *
44  * For universal references, that means that in the following function
45  * > template<class T> void func(T&& arg) { func2(std::forward<T>(arg)); }
46  *
47  *  - when called with arg being a rvalue reference or non-reference value, T gets inferred to be
48  *    a non-reference U, and std::forward<T>(t) will return U&&, correctly moving the argument.
49  *  - when called with arg behind a lvalue reference, T gets inferred to be U& because that's the only
50  *    way to match the signature (in C++, a type that is (T&)&& will collapse to T&).
51  *    That means std::forward<T>(t) will return U& and the value will not be moved but passed on as
52  *    a lvalue reference.
53  *
54  * How do we use that?
55  * ------------------------------------
56  * But std::forward can also be used outside of the common "universal forwarding" pattern to change
57  * reference types. So instead of following the common C++ pattern, we notice what
58  * std::forward<T>() actually does, and that is it takes a value and changes its reference to the
59  * type of reference passed in as T. If we don't infer T but explicitly specify it, we can use this
60  * to forward based on an explicitly specified reference type instead of the inferred argument type.
61  *
62  * This is why many of the dispatcher functions look like
63  * > template<class T> func(T t) { func2<T>(std::forward<T>(t)); }
64  * instead of the common
65  * > template<class T> func(T&& t) { func2(std::forward<T>(t)); }
66  *
67  * and are expected to be called by explicitly specifying the template parameters in a way that matches
68  * the expected operator signature at each call site.
69  */
70 
71 namespace impl {
72   // supported_primitive_arg_types defines which primitive types we allow in
73   // kernel functions as arguments or returns.
74   // Additionally, we support lists, dicts and optionals containing these types.
75   using supported_primitive_arg_types = guts::typelist::typelist<
76     int64_t,
77     double,
78     bool,
79     c10::string_view,
80     at::Tensor,
81     at::Scalar,
82     c10::QScheme,
83     c10::ScalarType,
84     c10::Device,
85     c10::DeviceIndex,
86     c10::Layout,
87     c10::MemoryFormat,
88     at::Dimname
89   >;
90 
91   // We have an unboxed functor in hand that takes C++ arguments, and
92   // we're building a boxed functor wrapper for it that takes IValues.
93   // So "outside" is boxed and "inside" is unboxed.
94   //
95   // So a valid input type is one that our boxed functor wrapper can
96   // unbox from an IValue into a C++ value.
97   //
  // Whereas a valid output type is one that our wrapper can receive
99   // as a C++ value from the unboxed functor, and box into an IValue.
100 
101   //
102   // assert_is_valid_input_type
103   // checks that T can be unboxed from an IValue into a C++ value.
104   //
105 
106   template<class T, bool AllowDeprecatedTypes, class Enable = void>
107   struct assert_is_valid_input_type {
assert_is_valid_input_typeassert_is_valid_input_type108     assert_is_valid_input_type() {
109       if constexpr (guts::typelist::contains<supported_primitive_arg_types, T>::value) {
110         /* everything is ok, this is a primitive type */
111       } else {
112         /* otherwise this must be an instance of a valid custom class, since it can only
113            have been created via IValue(x), which ensures this. */
114       }
115     }
116   };
117 
118   template<class T, bool AllowDeprecatedTypes>
119   struct assert_is_valid_input_type<std::optional<T>, AllowDeprecatedTypes>
120   : assert_is_valid_input_type<T, AllowDeprecatedTypes> {};
121 
122   template <bool AllowDeprecatedTypes, class... Args>
123   struct TypeCheckHelper;
124 
125   template <bool AllowDeprecatedTypes>
126   struct TypeCheckHelper<AllowDeprecatedTypes> {};
127 
128   template <bool AllowDeprecatedTypes, class Head, class... Rest>
129   struct TypeCheckHelper<AllowDeprecatedTypes, Head, Rest...>
130   : TypeCheckHelper<AllowDeprecatedTypes, Rest...> {
131     assert_is_valid_input_type<Head, AllowDeprecatedTypes> check;
132   };
133 
134   template<class... Contained, bool AllowDeprecatedTypes>
135   struct assert_is_valid_input_type<std::tuple<Contained...>, AllowDeprecatedTypes>
136   : TypeCheckHelper<AllowDeprecatedTypes, Contained...> {};
137 
138   template<class Key, class Value, bool AllowDeprecatedTypes>
139   struct assert_is_valid_input_type<Dict<Key, Value>, AllowDeprecatedTypes>
140   : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
141     static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
142       "You tried to register a kernel with an unsupported input type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
143   };
144 
145   template<class Key, class Value, bool AllowDeprecatedTypes>
146   struct assert_is_valid_input_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes>
147   : assert_is_valid_input_type<Value, AllowDeprecatedTypes> {
148     static_assert(AllowDeprecatedTypes,
149       "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
150     static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
151       "You tried to register a kernel with an unsupported input type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
152   };
153 
154   template<class T, bool AllowDeprecatedTypes>
155   struct assert_is_valid_input_type<List<T>, AllowDeprecatedTypes>
156   : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
157     static_assert(!std::is_same<T, at::Scalar>::value,
158       "You tried to register a kernel with an unsupported input type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
159   };
160 
161   template<class T, bool AllowDeprecatedTypes>
162   struct assert_is_valid_input_type<c10::ArrayRef<T>, AllowDeprecatedTypes>
163   : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
164     static_assert(!std::is_same<T, at::Scalar>::value,
165       "You tried to register a kernel with an unsupported input type: ArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
166   };
167 
168   template<class T, bool AllowDeprecatedTypes>
169   struct assert_is_valid_input_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes>
170   : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
171     static_assert(!std::is_same<T, at::Scalar>::value,
172       "You tried to register a kernel with an unsupported input type: OptionalArrayRef<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
173   };
174 
175   template<class T, size_t N, bool AllowDeprecatedTypes>
176   struct assert_is_valid_input_type<std::array<T, N>, AllowDeprecatedTypes>
177   : assert_is_valid_input_type<T, AllowDeprecatedTypes> {
178     static_assert(!std::is_same<T, at::Scalar>::value,
179       "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
180   };
181 
182   template<class T, bool AllowDeprecatedTypes>
183   struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
184     // There is no reason to support float when we have double. Keep the API lean.
185     static_assert(guts::false_t<T>::value,
186       "You tried to register a kernel with an unsupported input type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
187   };
188   template<class T, bool AllowDeprecatedTypes>
189   struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> {
190     static_assert(guts::false_t<T>::value,
191       "You tried to register a kernel with an unsupported input type: const char*. Please use c10::string_view instead.");
192   };
193   template<class T, bool AllowDeprecatedTypes>
194   struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> {
195     static_assert(guts::false_t<T>::value,
196       "You tried to register a kernel with an unsupported input type: vector<bool>. Please use List<bool> instead.");
197   };
198   template<class T, bool AllowDeprecatedTypes>
199   struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
200     static_assert(guts::false_t<T>::value,
201       "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
202   };
203   template<class T, bool AllowDeprecatedTypes>
204   struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const c10::SymInt&, T>::value>> {
205     static_assert(guts::false_t<T>::value,
206       "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
207   };
208 
209   // TODO: it probably would be good to tighten this up quite a bit more with
210   // an explicit list for everything
211 
212   //
213   // assert_is_valid_output_type
214   //
215 
216   template<class T, bool AllowDeprecatedTypes, class Enable = void>
217   struct assert_is_valid_output_type {
218     assert_is_valid_output_type() {
219       if constexpr(guts::typelist::contains<supported_primitive_arg_types, T>::value) {
220         /* everything is ok, this is a primitive type */
221       } else {
222         /* otherwise T is verified to be a registered custom class in the IValue
223           constructor, so no benefit in double-checking here */
224       }
225     }
226   };
227 
228   template<class T, bool AllowDeprecatedTypes>
229   struct assert_is_valid_output_type<std::optional<T>, AllowDeprecatedTypes>
230   : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
231 
232   template<class T, bool AllowDeprecatedTypes>
233   struct assert_is_valid_output_type<c10::OptionalArrayRef<T>, AllowDeprecatedTypes>
234   : assert_is_valid_output_type<T, AllowDeprecatedTypes> {};
235 
236   template<class Key, class Value, bool AllowDeprecatedTypes>
237   struct assert_is_valid_output_type<Dict<Key, Value>, AllowDeprecatedTypes>
238   : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
239     static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
240       "You tried to register a kernel with an unsupported output type: Dict<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
241     static_assert(!std::is_same<Value, at::Scalar>::value,
242       "You tried to register a kernel with an unsupported output type: Dict<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
243   };
244 
245   template<class Key, class Value, bool AllowDeprecatedTypes>
246   struct assert_is_valid_output_type<std::unordered_map<Key, Value>, AllowDeprecatedTypes>
247   : assert_is_valid_output_type<Value, AllowDeprecatedTypes> {
248     static_assert(AllowDeprecatedTypes,
249       "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value>. Please use Dict<Key, Value> instead.");
250     static_assert(guts::typelist::contains<impl::valid_dict_key_types, Key>::value,
251       "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Value> where Key is invalid. We only support int64_t, double, bool, and string.");
252     static_assert(!std::is_same<Value, at::Scalar>::value,
253       "You tried to register a kernel with an unsupported output type: std::unordered_map<Key, Scalar>. Please use Dict<Key, int64_t> or Dict<Key, double>.");
254   };
255 
256   template<class T, bool AllowDeprecatedTypes>
257   struct assert_is_valid_output_type<List<T>, AllowDeprecatedTypes>
258   : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
259     static_assert(!std::is_same<T, at::Scalar>::value,
260       "You tried to register a kernel with an unsupported output type: List<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
261   };
262 
263   template<class T, bool AllowDeprecatedTypes>
264   struct assert_is_valid_output_type<std::vector<T>, AllowDeprecatedTypes>
265   : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
266     static_assert(!std::is_same<T, at::Scalar>::value,
267       "You tried to register a kernel with an unsupported output type: std::vector<Scalar>. Please use List<int64_t>, List<double> or Tensor instead.");
268     // TODO static_assert(AllowDeprecatedTypes, "You tried to register a kernel with an unsupported output type: std::vector<T>. Please use List<T> instead.");
269   };
270 
271   template<class T, size_t N, bool AllowDeprecatedTypes>
272   struct assert_is_valid_output_type<std::array<T, N>, AllowDeprecatedTypes>
273   : assert_is_valid_output_type<T, AllowDeprecatedTypes> {
274     static_assert(!std::is_same<T, at::Scalar>::value,
275       "You tried to register a kernel with an unsupported output type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
276   };
277 
278   // The following specialisations of assert_is_valid_output_type are technically not
279   // necessary since we would hit the base case and show an error message
280   // there if they didn't exist, but we can show a better error message
281   // in some common error scenarios.
282   template<class T, bool AllowDeprecatedTypes>
283   struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
284     // There is no reason to support float when we have double. Keep the API lean.
285     static_assert(guts::false_t<T>::value,
286       "You tried to register a kernel with an unsupported output type: float. Please use double instead; you should use `double` in the C++ function signature and `float` in the schema string.");
287   };
288   template<class T, bool AllowDeprecatedTypes>
289   struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const char*, T>::value>> {
290     static_assert(guts::false_t<T>::value,
291       "You tried to register a kernel with an unsupported output type: const char*. Please use c10::string_view instead.");
292   };
293   template<class T, bool AllowDeprecatedTypes>
294   struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<std::vector<bool>, T>::value>> {
295     static_assert(guts::false_t<T>::value,
296       "You tried to register a kernel with an unsupported output type: vector<bool>. Please use List<bool> instead.");
297   };
298   template<class T, bool AllowDeprecatedTypes>
299   struct assert_is_valid_output_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_integral<T>::value && !guts::typelist::contains<supported_primitive_arg_types, T>::value>> {
300     static_assert(guts::false_t<T>::value,
301       "You tried to register a kernel with an unsupported integral output type. Please use int64_t instead; you should use `int64_t` in the C++ function signature and `int` in the schema string.");
302   };
303 
304   // ivalue_to_arg
305 
306   template<class T>
307   struct decay_if_not_tensor final {
308     using type = std::decay_t<T>;
309   };
310 
311   template<>
312   struct decay_if_not_tensor<at::Tensor&> final {
313     using type = at::Tensor&;
314   };
315 
316   template<>
317   struct decay_if_not_tensor<const at::Tensor&> final {
318     using type = const at::Tensor&;
319   };
320 
321   template<class T, bool AllowDeprecatedTypes>
322   struct ivalue_to_arg final {
323     static decltype(auto) call(IValue& v) {
324       assert_is_valid_input_type<T, AllowDeprecatedTypes>();
325       return std::move(v).to<T>();
326     }
327   };
328 
329   // The following two specializations take advantage of specialized
330   // `toTensor()` overloads on IValue to avoid copying.
331   template<bool AllowDeprecatedTypes>
332   struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
333     // We cannot use the default implementation if they asked for a
334     // `at::Tensor&` because it moves from the IValue, so it can't get
335     // an lvalue reference.
336     static at::Tensor& call(IValue& v) {
337       // Tensor& is valid, don't bother asserting
338       return v.toTensor();
339     }
340   };
341 
342   template<bool AllowDeprecatedTypes>
343   struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
344     // We should not use the default implementation if they asked for
345     // a `const at::Tensor&` because it moves from the IValue and they
346     // didn't ask for that.
347     static const at::Tensor& call(IValue& v) {
348       // const Tensor& is valid, don't bother asserting
349       return v.toTensor();
350     }
351   };
352 
353   template<bool AllowDeprecatedTypes>
354   struct ivalue_to_arg<at::ITensorListRef, AllowDeprecatedTypes> final {
355     static List<at::Tensor> call(IValue& v) {
356       return v.toTensorList();
357     }
358   };
359 
360   template<class T, bool AllowDeprecatedTypes>
361   struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
362     // If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that
363     // to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>.
364     static std::vector<T> call(IValue& v) {
365       return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
366     }
367   };
368   template<bool AllowDeprecatedTypes>
369   struct ivalue_to_arg<c10::SymIntArrayRef, AllowDeprecatedTypes> final {
370     static std::vector<c10::SymInt> call(IValue& v) {
371       if (v.isIntList()) {
372         std::vector<c10::SymInt> r;
373         auto src = v.toIntList();
374         std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
375         return r;
376       } else {
377         return ivalue_to_arg<std::vector<c10::SymInt>, AllowDeprecatedTypes>::call(v);
378       }
379     }
380   };
381   template<bool AllowDeprecatedTypes>
382   struct ivalue_to_arg<c10::OptionalArray<c10::SymInt>, AllowDeprecatedTypes> final {
383     static OptionalArray<c10::SymInt> call(IValue& v) {
384       if (v.isIntList()) {
385         std::vector<c10::SymInt> r;
386         auto src = v.toIntList();
387         std::transform(src.begin(), src.end(), std::back_inserter(r), [](int64_t i) { return c10::SymInt(i); });
388         return OptionalArray<c10::SymInt>(std::move(r));
389       } else {
390         return std::move(v).to<OptionalArray<c10::SymInt>>();
391       }
392     }
393   };
394   template<class T, bool AllowDeprecatedTypes>
395   struct ivalue_to_arg<std::optional<ArrayRef<T>>, AllowDeprecatedTypes> final {
396     // If an argument is std::optional<ArrayRef<T>>, convert the IValue to an std::optional<std::vector<T>> and pass that
397     // to the operator. OptionalArray<T> is basically a std::optional<std::vector<T>> but implicitly convertible
398     // to std::optional<ArrayRef<T>>.
399     static OptionalArray<T> call(IValue& v) {
400       return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
401     }
402   };
403 
404   template<class T, bool AllowDeprecatedTypes>
405   struct ivalue_to_arg<OptionalArrayRef<T>, AllowDeprecatedTypes> final {
406     // If an argument is OptionalArrayRef<T>, convert the IValue to an
407     // std::optional<std::vector<T>> and pass that to the operator. OptionalArray<T>
408     // is basically a std::optional<std::vector<T>> but implicitly convertible to
409     // OptionalArrayRef<T>
410     static OptionalArray<T> call(IValue& v) {
411       return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
412     }
413   };
414 
415   // return_to_ivalue
416   template<class T, bool AllowDeprecatedTypes, class Enable = void>
417   struct return_to_ivalue final {};
418 
419   template<class T, bool AllowDeprecatedTypes>
420   struct return_to_ivalue<T, AllowDeprecatedTypes, std::enable_if_t<!std::is_same<at::Tensor&, T>::value>> final {
421     static IValue call(T&& v) {
422       assert_is_valid_output_type<T, AllowDeprecatedTypes>();
423       return c10::ivalue::from(std::move(v));
424     }
425     static IValue copy(const T& v) {
426       assert_is_valid_output_type<T, AllowDeprecatedTypes>();
427       return IValue(v);
428     }
429   };
430 
431   // Special case to allow kernels to return `Tensor&`.
432   // TODO Delete this once kernels don't do that anymore
433   template<bool AllowDeprecatedTypes>
434   struct return_to_ivalue<at::Tensor&, AllowDeprecatedTypes, void> final {
435     static IValue call(at::Tensor& v) {
436       return c10::ivalue::from(v);
437     }
438     static IValue copy(at::Tensor& v) {
439       return IValue(v);
440     }
441   };
442 
443   // wrap_kernel_functor_unboxed_
444 
445   template<class KernelFunctor, class OpSignature>
446   struct wrap_kernel_functor_unboxed_ final {};
447 
448   // This specialization is for kernels with a first argument that is NOT of type DispatchKeySet
449   // This includes kernels with 0 arguments.
450   template<class KernelFunctor, class ReturnType, class... ParameterTypes>
451   struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(ParameterTypes...)> final {
452     static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value,
453       "Return type mismatch");
454     static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value,
455       "Parameter types mismatch");
456 
457     // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use &&
458     static ReturnType call(OperatorKernel* functor, DispatchKeySet, ParameterTypes... args) {
459       KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
460       // Note [Plumbing Keys Through The Dispatcher 2]
461       // See Note [Plumbing Keys Through The Dispatcher] for the background.
462       // This functor explicitly takes in a dispatchKeySet and drops it on the floor- it does not forward it to the registered kernel.
463       //
464       // This is due to the calling convention within the dispatcher, which expects all registered kernels to have a first argument of type
465       // DispatchKeySet.
466       // This is not the case for pretty much all manually written kernels, however- this functor serves to separate the calling convention
467       // of the dispatcher from the calling convention of manually written kernels.
468       return (*functor_)(std::forward<ParameterTypes>(args)...);
469     }
470   };
471 
472   // This specialization is for kernels with a first argument of type DispatchKeySet
473   template<class KernelFunctor, class ReturnType, class... ParameterTypes>
474   struct wrap_kernel_functor_unboxed_<KernelFunctor, ReturnType(DispatchKeySet, ParameterTypes...)> final {
475     static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<KernelFunctor>::return_type>::value,
476       "Return type mismatch");
477     static_assert(std::is_same<guts::typelist::typelist<DispatchKeySet, ParameterTypes...>, typename guts::infer_function_traits_t<KernelFunctor>::parameter_types>::value,
478       "Parameter types mismatch");
479 
480     // See [Note: Argument forwarding in the dispatcher] for why ParameterTypes doesn't use &&
481     static ReturnType call(OperatorKernel* functor, DispatchKeySet dispatchKeySet, ParameterTypes... args) {
482       KernelFunctor* functor_ = static_cast<KernelFunctor*>(functor);
483       // We're explicitly taking in a dispatchKeySet and forwarding it to the registered kernel.
484       // See Note [Plumbing Keys Through The Dispatcher 2] for details.
485       return (*functor_)(dispatchKeySet, std::forward<ParameterTypes>(args)...);
486     }
487   };
488 
489   template<class KernelFunctor>
490   using wrap_kernel_functor_unboxed = wrap_kernel_functor_unboxed_<KernelFunctor, typename guts::infer_function_traits_t<KernelFunctor>::func_type>;
491 
492   // call_functor_with_args_from_stack
493 
494   template<class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices,  typename... ArgTypes>
495   std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
496   call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) {
497     (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
498 
499     // We're explicitly filtering out DispatchKeySet from the argument list.
500     // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
501     // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
502     // See Note [Plumbing Keys Through The Dispatcher] for the background.
503     return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet,
504       ivalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type, AllowDeprecatedTypes>::call(
505         torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices))
506     )...);
507   }
508 
509   template<class Functor, bool AllowDeprecatedTypes>
510   std::decay_t<typename guts::infer_function_traits_t<Functor>::return_type>
511   call_functor_with_args_from_stack(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack) {
512     // We're explicitly filtering out DispatchKeySet from the argument list.
513     // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
514     // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
515     // See Note [Plumbing Keys Through The Dispatcher] for the background.
516     using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<Functor>::parameter_types;
517     constexpr size_t num_ivalue_args = guts::typelist::size<ArgTypes>::value;
518     return call_functor_with_args_from_stack_<Functor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack, std::make_index_sequence<num_ivalue_args>(), static_cast<ArgTypes*>(nullptr));
519   }
520 
521   // push_outputs
522 
523   template<class OutputType, bool AllowDeprecatedTypes>
524   struct push_outputs final {
525     // Contrary to [Note: Argument forwarding in the dispatcher], we use OutputType&& here
526     // to avoid one extra call to the move constructor in this case. This is still not a
527     // universal reference though because OutputType is an explicitly specified class
528     // template parameter.
529     static void call(OutputType&& output, Stack* stack) {
530       torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::call(std::forward<OutputType>(output)));
531     }
532     static void copy(const OutputType& output, Stack* stack) {
533       torch::jit::push(*stack, return_to_ivalue<OutputType, AllowDeprecatedTypes>::copy(output));
534     }
535   };
536   template<class... OutputTypes, bool AllowDeprecatedTypes>
537   struct push_outputs<std::tuple<OutputTypes...>, AllowDeprecatedTypes> final {
538     static void call(std::tuple<OutputTypes...>&& output, Stack* stack) {
539       call_(std::move(output), stack, std::make_index_sequence<sizeof...(OutputTypes)>());
540     }
541     static void copy(const std::tuple<OutputTypes...>& output, Stack* stack) {
542       copy_(output, stack, std::make_index_sequence<sizeof...(OutputTypes)>());
543     }
544 
545   private:
546     template<size_t... indices>
547     static void call_(std::tuple<OutputTypes...>&& output, Stack* stack, std::index_sequence<indices...>) {
548       torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::call(std::forward<OutputTypes>(std::get<indices>(output)))...);
549     }
550     template<size_t... indices>
551     static void copy_(const std::tuple<OutputTypes...>& output, Stack* stack, std::index_sequence<indices...>) {
552       torch::jit::push(*stack, return_to_ivalue<OutputTypes, AllowDeprecatedTypes>::copy(std::get<indices>(output))...);
553     }
554   };
555   template<bool AllowDeprecatedTypes>
556   struct push_outputs<void, AllowDeprecatedTypes> final {
557     static void call(int /*dummy*/, Stack* /*stack*/) {
558     }
559     static void copy(int /*dummy*/, Stack* /*stack*/) {
560     }
561   };
562 
563   // make_boxed_from_unboxed_functor
564 
565   template<class KernelFunctor, bool AllowDeprecatedTypes>
566   struct make_boxed_from_unboxed_functor final {
567     static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value,
568       "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
569 
570     static void call(OperatorKernel* functor, const OperatorHandle&, DispatchKeySet dispatchKeySet, Stack* stack) {
571       using ReturnType = typename guts::infer_function_traits_t<KernelFunctor>::return_type;
572       // We're explicitly filtering out DispatchKeySet from the argument list.
573       // Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
574       // We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
575       // See Note [Plumbing Keys Through The Dispatcher] for the background.
576       using ArgTypes = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::parameter_types;
577       constexpr bool has_outputs = !std::is_same<void, ReturnType>::value;
578       constexpr size_t num_inputs = guts::typelist::size<ArgTypes>::value;
579       if constexpr (has_outputs) {
580         // Decay ReturnType to ReturnType_ so that if a reference gets returned, we actually store it by value
581         // and don't get a dangling reference. This is only required because some kernels still return `Tensor&`.
582         // [Note: VC++ and 'std': ambiguous symbol]
583         using ReturnType_ = ::std::decay_t<ReturnType>;
584         ReturnType_ output = call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
585         torch::jit::drop(*stack, num_inputs);
586         // See note [ VC++ and 'std': ambiguous symbol]
587         push_outputs<ReturnType_, AllowDeprecatedTypes>::call(::std::move(output), stack);
588       } else {
589         call_functor_with_args_from_stack<KernelFunctor, AllowDeprecatedTypes>(functor, dispatchKeySet, stack);
590         torch::jit::drop(*stack, num_inputs);
591       }
592     }
593   };
594 } // namespace impl
595 
596 } // namespace c10
597 
namespace torch {
  // Makes c10::OperatorKernel available under the name torch::OperatorKernel.
  using OperatorKernel = c10::OperatorKernel;
}
601