#pragma once

/**
 * Include this file if you want to register operators. It includes all
 * functionality needed to do so for you.
 */

#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/CompileTimeFunctionPointer.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/core/op_registration/infer_schema.h>
#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#endif
#include <ATen/core/ATenOpList.h>

namespace c10 {

namespace detail {
// The first argument of the schema might be of type DispatchKeySet, in which case we remove it.
// We do this because every argument in a function schema is expected to be convertible
// to an IValue, but DispatchKeySet is not a type we want the JIT to be aware of.
// See Note [Plumbing Keys Through The Dispatcher]
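// For example, a kernel with C++ signature
//   Tensor(DispatchKeySet ks, Tensor a, Tensor b)
// is treated as if it were Tensor(Tensor a, Tensor b) when inferring the schema.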
template<class KernelFunctor>
std::unique_ptr<FunctionSchema> inferFunctionSchemaFromFunctor() {
  using func_type = typename c10::remove_DispatchKeySet_arg_from_func<KernelFunctor>::func_type;
  return std::make_unique<FunctionSchema>(inferFunctionSchemaFlattenedReturns<func_type>());
}
}

/**
 * An instance of this class handles the registration for one or more operators.
 * Make sure you keep the RegisterOperators instance around, since it will
 * deregister the operators it's responsible for in its destructor.
 *
 * Example:
 *
 * > namespace {
 * >   class my_kernel_cpu final : public c10::OperatorKernel {
 * >   public:
 * >     Tensor operator()(Tensor a, Tensor b) {...}
 * >   };
 * > }
 * >
 * > static auto registry = c10::RegisterOperators()
 * >     .op(c10::RegisterOperators::options()
 * >         .schema("my_op")
 * >         .kernel<my_kernel_cpu>(DispatchKey::CPU));
 */
class TORCH_API RegisterOperators final {
public:
  RegisterOperators() = default;
  ~RegisterOperators() = default;

  RegisterOperators(const RegisterOperators&) = delete;
  RegisterOperators& operator=(const RegisterOperators&) = delete;
  RegisterOperators(RegisterOperators&&) noexcept = default;
  RegisterOperators& operator=(RegisterOperators&&) noexcept = default;

  class TORCH_API Options final {
  public:
    Options(const Options&) = delete;
    Options(Options&&) noexcept = delete;
    Options& operator=(const Options&) = delete;
    Options& operator=(Options&&) noexcept = delete;

    // internal-only for registering stack based kernels
    template<KernelFunction::BoxedKernelFunction* kernel_func>
    Options&& kernel(DispatchKey dispatch_key) && {
      return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
    }

    // internal-only for registering stack based catch-all kernels
    template<KernelFunction::BoxedKernelFunction* kernel_func>
    Options&& catchAllKernel() && {
      return std::move(*this).kernel(std::nullopt, KernelFunction::makeFromBoxedFunction<kernel_func>(), std::nullopt, nullptr);
    }

    // internal only for registering caffe2 ops
    Options&& schema(FunctionSchema&& schema) {
        TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration.");
        schemaOrName_ = FunctionSchema(std::move(schema));
        return std::move(*this);
    }

    /**
     * Use this to specify the schema for an operator. You can also specify
     * only the operator name, in which case the function signature part of the
     * schema is inferred from the kernel function.
     *
     * Example:
     *
     * > // Infer function signature from my_kernel_cpu
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .kernel<my_kernel_cpu>(DispatchKey::CPU));
     * >
     * >
     * > // Explicitly specify full schema
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op(Tensor a) -> Tensor")
     * >         .kernel<my_kernel_cpu>(DispatchKey::CPU));
     */
    Options&& schema(const std::string& schemaOrName) {
      TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration.");

      #if !defined(EXPOSE_C2_OPS) && defined(CAFFE2_IS_XPLAT_BUILD)
        throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
      #else
        schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName);
      #endif

      return std::move(*this);
    }

    /**
     * Use this to register an operator whose kernel is implemented as a functor.
     * The kernel is only called for inputs matching the given dispatch key.
     * You can register multiple kernels for different dispatch keys.
     *
     * Example:
     *
     * > namespace {
     * >   class my_kernel_cpu final : public c10::OperatorKernel {
     * >   public:
     * >     Tensor operator()(Tensor a, Tensor b) {...}
     * >   };
     * > }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .kernel<my_kernel_cpu>(DispatchKey::CPU));
     *
     * The functor constructor can take arguments to configure the kernel.
     * These arguments are supplied at the kernel registration call site.
     * Example:
     *
     * > namespace {
     * >   class my_kernel_cpu final : public c10::OperatorKernel {
     * >   public:
     * >     explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
     * >         : ... {...}
     * >
     * >     Tensor operator()(Tensor a, Tensor b) {...}
     * >   };
     * > }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .kernel<my_kernel_cpu>(DispatchKey::CPU, "some_configuration", 3, true));
     */
    template<class KernelFunctor, class... ConstructorParameters>
    // enable_if: only enable it if KernelFunctor is actually a functor
    std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && {
      static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
      static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");

      return std::move(*this).kernel(
        dispatch_key,
        KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
        impl::CppSignature::make<KernelFunctor>(),
        detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
      );
    }

    /**
     * Use this to register an operator whose kernel is implemented as a functor.
     * The kernel is a catch-all kernel, meaning it's called independently of
     * the input. Dispatch is disabled for this operator.
     *
     * Example:
     *
     * > namespace {
     * >   class my_kernel_cpu final : public c10::OperatorKernel {
     * >   public:
     * >     Tensor operator()(Tensor a, Tensor b) {...}
     * >   };
     * > }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .catchAllKernel<my_kernel_cpu>());
     *
     * The functor constructor can take arguments to configure the kernel.
     * These arguments are supplied at the kernel registration call site.
     * Example:
     *
     * > namespace {
     * >   class my_kernel_cpu final : public c10::OperatorKernel {
     * >   public:
     * >     explicit my_kernel_cpu(std::string some_configuration, int a, bool b)
     * >         : ... {...}
     * >
     * >     Tensor operator()(Tensor a, Tensor b) {...}
     * >   };
     * > }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .catchAllKernel<my_kernel_cpu>("some_configuration", 3, true));
     */
    template<class KernelFunctor, class... ConstructorParameters>
    // enable_if: only enable it if KernelFunctor is actually a functor
    std::enable_if_t<guts::is_functor<KernelFunctor>::value, Options&&> catchAllKernel(ConstructorParameters&&... constructorParameters) && {
      static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to register a kernel functor using the kernel<Functor>() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");
      static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel<Functor>(arguments...) must match one of the constructors of Functor.");

      return std::move(*this).kernel(
        std::nullopt,
        KernelFunction::makeFromUnboxedFunctor<false, KernelFunctor>(std::make_unique<KernelFunctor>(std::forward<ConstructorParameters>(constructorParameters)...)),
        impl::CppSignature::make<KernelFunctor>(),
        detail::inferFunctionSchemaFromFunctor<KernelFunctor>()
      );
    }

    /**
     * Use this to register an operator whose kernel is implemented by a function.
     * The kernel is only called for inputs matching the given dispatch key.
     * You can register multiple kernels for different dispatch keys.
     *
     * Example:
     *
     * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>(DispatchKey::CPU));
     */
    template<class FuncType, FuncType* kernel_func>
    // enable_if: only enable it if FuncType is actually a function
    std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key) && {
      static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
      static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");

      return std::move(*this).kernel(
        dispatch_key,
        KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
        impl::CppSignature::make<FuncType>(),
        // TODO Do schema inference without relying on WrapFunctionIntoFunctor
        detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
      );
    }

    /**
     * Use this to register an operator whose kernel is implemented by a function.
     * The kernel is a catch-all kernel, meaning it's called independently of
     * the input. Dispatch is disabled for this operator.
     *
     * Example:
     *
     * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .catchAllKernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
     */
    template<class FuncType, FuncType* kernel_func>
    // enable_if: only enable it if FuncType is actually a function
    std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel() && {
      static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
      static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr");

      return std::move(*this).kernel(
        std::nullopt,
        KernelFunction::makeFromUnboxedFunction(TORCH_FN(kernel_func)),
        impl::CppSignature::make<FuncType>(),
        // TODO Do schema inference without relying on WrapFunctionIntoFunctor
        detail::inferFunctionSchemaFromFunctor<typename impl::WrapFunctionIntoFunctor<CompileTimeFunctionPointer<FuncType, kernel_func>>::type>()
      );
    }

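    /**
     * Use this to register an operator whose kernel is passed as a runtime
     * function pointer rather than as a template argument. Prefer the
     * compile-time kernel<decltype(fn), &fn>() overload above where possible,
     * since it can have slightly lower calling overhead.
     *
     * Example (a sketch following the same pattern as the overload above):
     *
     * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .kernel(DispatchKey::CPU, &my_kernel_cpu));
     */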
    template<class FuncType>
    // enable_if: only enable it if FuncType is actually a function
    std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && {
      static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
      TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");

      return std::move(*this).kernel(
        dispatch_key,
        KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
        impl::CppSignature::make<FuncType>(),
        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
      );
    }

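    /**
     * Use this to register a catch-all kernel passed as a runtime function
     * pointer. Like the other catch-all overloads, the kernel is called
     * independently of the input and dispatch is disabled for this operator.
     *
     * Example (a sketch following the same pattern as the overload above):
     *
     * > namespace { Tensor my_kernel(Tensor a, Tensor b) {...} }
     * >
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .catchAllKernel(&my_kernel));
     */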
    template<class FuncType>
    // enable_if: only enable it if FuncType is actually a function
    std::enable_if_t<guts::is_function_type<FuncType>::value, Options&&> catchAllKernel(FuncType* kernel_func) && {
      static_assert(!std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API.");
      TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr");

      return std::move(*this).kernel(
        std::nullopt,
        KernelFunction::makeFromUnboxedRuntimeFunction(kernel_func),
        impl::CppSignature::make<FuncType>(),
        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
      );
    }

    /**
     * Use this to register an operator whose kernel is implemented as a lambda.
     * The kernel is only called for inputs matching the given dispatch key.
     * You can register multiple kernels for different dispatch keys.
     *
     * The lambda must be stateless, i.e. not have a capture. If your kernel
     * needs to store some configuration parameters, write the kernel as a
     * functor instead.
     *
     * Example:
     *
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .kernel(DispatchKey::CPU, [] (Tensor a) -> Tensor {...}));
     */
    template<class Lambda>
    // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
    std::enable_if_t<
        guts::is_functor<std::decay_t<Lambda>>::value
        && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
        Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && {
      static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");

      // We don't support stateful lambdas (i.e. lambdas with a capture), because their
      // behavior would be non-obvious. A functor kernel with a cache gets a new instance of
      // its cache each time the kernel is looked up from the dispatch table.
      // A lambda with a capture would be global and share its capture between all kernel lookups.
      // So, instead of making users have to think about it (including the thread-safety
      // issues this causes), let's just forbid stateful lambdas altogether.
      static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");

      return std::move(*this).kernel(
        dispatch_key,
        KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(functor)),
        impl::CppSignature::make<Lambda>(),
        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
      );
    }

    /**
     * Use this to register an operator whose kernel is implemented as a lambda.
     * The kernel is a catch-all kernel, meaning it's called independently of
     * the input. Dispatch is disabled for this operator.
     *
     * The lambda must be stateless, i.e. not have a capture. If your kernel
     * needs to store some configuration parameters, write the kernel as a
     * functor instead.
     *
     * Example:
     *
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op")
     * >         .catchAllKernel([] (Tensor a) -> Tensor {...}));
     */
    template<class Lambda>
    // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
    std::enable_if_t<
        guts::is_functor<std::decay_t<Lambda>>::value
        && !std::is_same<typename guts::infer_function_traits_t<std::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
        Options&&> catchAllKernel(Lambda&& lambda) && {
      static_assert(!std::is_base_of<OperatorKernel, std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");

      // We don't support stateful lambdas (i.e. lambdas with a capture), because their
      // behavior would be non-obvious.
      // A lambda with a capture would be global and share its capture between all kernel lookups.
      // This would be a likely source of unexpected race conditions, so we forbid it.
      // If a kernel really needs global state, it can just use regular global state
      // in its .cpp file next to the kernel lambda.
      static_assert(guts::is_stateless_lambda<std::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel only works for stateless lambdas (i.e. lambdas without captures). If you need a cache, please use the functor based API kernel<Functor>() instead.");

      return std::move(*this).kernel(
        std::nullopt,
        KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(lambda)),
        impl::CppSignature::make<Lambda>(),
        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
      );
    }

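    /**
     * Use this to specify the AliasAnalysisKind the JIT's alias analysis should
     * use for this operator. It can only be called once per operator registration.
     *
     * Example (a sketch; assumes AliasAnalysisKind::PURE_FUNCTION is the
     * behavior you want for your operator):
     *
     * > static auto registry = c10::RegisterOperators()
     * >     .op(c10::RegisterOperators::options()
     * >         .schema("my_op(Tensor a) -> Tensor")
     * >         .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION)
     * >         .catchAllKernel([] (Tensor a) -> Tensor {...}));
     */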
    Options&& aliasAnalysis(AliasAnalysisKind aliasAnalysisKind) && {
      TORCH_CHECK(!aliasAnalysisKind_.has_value(), "You can only call aliasAnalysis() once per operator registration.");
      aliasAnalysisKind_ = aliasAnalysisKind;
      return std::move(*this);
    }

  private:
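    // All public kernel()/catchAllKernel() overloads funnel into this overload,
    // which appends one KernelRegistrationConfig entry to `kernels`.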
    Options&& kernel(std::optional<DispatchKey> dispatch_key, KernelFunction&& func, std::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema>&& inferred_function_schema) && {
      KernelRegistrationConfig config;
      config.dispatch_key = dispatch_key;
      config.func = std::move(func);
      config.cpp_signature = cpp_signature;
      config.inferred_function_schema = std::move(inferred_function_schema);
      kernels.push_back(std::move(config));
      return std::move(*this);
    }

    Options()
    : schemaOrName_(std::nullopt)
    , kernels()
    , aliasAnalysisKind_(std::nullopt)
    {}

    // KernelRegistrationConfig accumulates all information from the config
    // parameters passed to a RegisterOperators::op() call into one object.
    struct KernelRegistrationConfig final {
      KernelRegistrationConfig()
        : dispatch_key(std::nullopt)
        , func()
        , cpp_signature(std::nullopt)
        , inferred_function_schema(nullptr)
      {}

      std::optional<DispatchKey> dispatch_key;
      KernelFunction func;
      std::optional<impl::CppSignature> cpp_signature;
      std::unique_ptr<FunctionSchema> inferred_function_schema;
    };

    std::optional<std::variant<OperatorName, FunctionSchema>> schemaOrName_;

    std::vector<KernelRegistrationConfig> kernels;
    std::optional<AliasAnalysisKind> aliasAnalysisKind_;
    friend class RegisterOperators;
    friend class Library;
  };

  /**
   * Call this to get an instance of registration options that can be
   * passed to RegisterOperators::op() to configure an operator
   * registration.
   * See the class doc comment for examples.
   */
  static Options options() {
    return {};
  }

  /**
   * Call this to register an operator. See the class doc comment for examples.
   */
  RegisterOperators&& op(Options&& options) && {
    checkSchemaAndRegisterOp_(std::move(options));
    return std::move(*this);
  }

  // Regular mutator version of the && version above
  RegisterOperators& op(Options&& options) & {
    checkSchemaAndRegisterOp_(std::move(options));
    return *this;
  }

  /**
   * This is a shorthand for RegisterOperators::op(Options) where you can
   * specify the operator schema outside of the options parameter.
   * See the class doc comment for examples.
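   *
   * For example (a sketch using the same patterns as the class doc comment):
   *
   * > static auto registry = c10::RegisterOperators()
   * >     .op("my_op(Tensor a) -> Tensor",
   * >         c10::RegisterOperators::options()
   * >             .kernel<my_kernel_cpu>(DispatchKey::CPU));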
   */
  RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && {
    return std::move(*this).op(std::move(options).schema(schemaOrName));
  }

  // internal only for registering caffe2 ops
  RegisterOperators&& op(FunctionSchema schema, Options&& options) && {
    return std::move(*this).op(std::move(options).schema(std::move(schema)));
  }

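  /**
   * Convenience constructor: registers a single operator, equivalent to calling
   * RegisterOperators().op(schemaOrName, std::forward<FuncType>(func), std::move(options)).
   */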
  template<class FuncType>
  explicit RegisterOperators(const std::string& schemaOrName, FuncType&& func, Options&& options = RegisterOperators::options())
  : RegisterOperators() {
    std::move(*this).op(schemaOrName, std::forward<FuncType>(func), std::move(options));
  }

  /**
   * This API registers an operator based on a kernel function pointer.
   *
   * Given a kernel
   *
   * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} }
   *
   * This API looks like:
   *
   * > static auto registry = c10::RegisterOperators()
   * >     .op("my_op", &my_kernel_cpu);
   *
   * If your kernel is small and the overhead of calling it matters,
   * then this API might be the wrong choice since the following API
   * has a slightly lower overhead for calling into the kernel:
   *
   * > static auto registry = c10::RegisterOperators()
   * >     .op("my_op", c10::RegisterOperators::options()
   * >         .kernel<decltype(my_kernel_cpu), &my_kernel_cpu>());
   *
   * Or, alternatively, write your kernel as a functor:
   *
   * > namespace {
   * >   class my_kernel_cpu final : public c10::OperatorKernel {
   * >   public:
   * >     Tensor operator()(Tensor a, Tensor b) {...}
   * >   };
   * > }
   * >
   * > static auto registry = c10::RegisterOperators()
   * >     .op("my_op", c10::RegisterOperators::options()
   * >         .kernel<my_kernel_cpu>());
   */
   template<class FuncType>
   // enable_if: only enable it if FuncType is actually a function, but not a stack based BoxedKernelFunction.
   std::enable_if_t<guts::is_function_type<FuncType>::value && !std::is_same<FuncType, KernelFunction::BoxedKernelFunction>::value, RegisterOperators&&>
   op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && {
     constexpr bool AllowLegacyTypes = true;
     return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
       std::nullopt,
       KernelFunction::makeFromUnboxedRuntimeFunction<AllowLegacyTypes>(func),
       impl::CppSignature::make<FuncType>(),
       // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
       detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>()
     ));
   }

   /**
    * This API registers an operator based on a kernel lambda.
    *
    * This API looks like:
    *
    * > static auto registry = c10::RegisterOperators()
    * >     .op("my_op", [] (Tensor a, Tensor b) {...});
    *
    * This is equivalent to:
    *
    * > static auto registry = c10::RegisterOperators()
    * >     .op("my_op", c10::RegisterOperators::options()
    * >         .catchAllKernel([] (Tensor a, Tensor b) {...}));
    *
    */
    template<class Lambda>
    // enable_if: only enable it if Lambda is actually a stateless lambda
    std::enable_if_t<guts::is_functor<Lambda>::value && guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
    op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
      static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");

      constexpr bool AllowLegacyTypes = true;
      return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
        std::nullopt,
        KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
        impl::CppSignature::make<Lambda>(),
        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
      ));
    }

    template<class Lambda>
    C10_DEPRECATED_MESSAGE("Registering operator kernels with stateful lambdas (i.e. lambdas with a capture) has non-obvious behavior. This is deprecated. Please use a lambda without a capture or a functor class instead.")
    // enable_if: only enable it if Lambda is actually a functor but not a stateless lambda
    std::enable_if_t<guts::is_functor<Lambda>::value && !guts::is_stateless_lambda<std::decay_t<Lambda>>::value, RegisterOperators&&>
    op(const std::string& schemaOrName, Lambda&& lambda, Options&& options = RegisterOperators::options()) && {
      static_assert(!std::is_base_of<OperatorKernel, Lambda>::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead.");

      constexpr bool AllowLegacyTypes = true;
      return std::move(*this).op(std::move(options).schema(schemaOrName).kernel(
        std::nullopt,
        KernelFunction::makeFromUnboxedLambda<AllowLegacyTypes>(std::forward<Lambda>(lambda)),
        impl::CppSignature::make<Lambda>(),
        // TODO Do schema inference without relying on WrapFunctionIntoRuntimeFunctor
        detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>()
      ));
    }

private:
  void checkSchemaAndRegisterOp_(Options&& config);

  static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options);
  void checkNoDuplicateKernels_(const Options& options);
  void registerOp_(Options&& options);

  std::vector<RegistrationHandleRAII> registrars_;
};

} // namespace c10

namespace torch {
  // Old-style API
  using RegisterOperators = c10::RegisterOperators;
}
597