xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/nnc/registry.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 #include <c10/util/Registry.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace mobile {
9 namespace nnc {
10 
11 using nnc_kernel_function_type = int(void**);
12 
13 struct TORCH_API NNCKernel {
14   virtual ~NNCKernel() = default;
15   virtual int execute(void** /* args */) = 0;
16 };
17 
18 TORCH_DECLARE_REGISTRY(NNCKernelRegistry, NNCKernel);
19 
20 #define REGISTER_NNC_KERNEL(id, kernel, ...)     \
21   extern "C" {                                   \
22   nnc_kernel_function_type kernel;               \
23   }                                              \
24   struct NNCKernel_##kernel : public NNCKernel { \
25     int execute(void** args) override {          \
26       return kernel(args);                       \
27     }                                            \
28   };                                             \
29   C10_REGISTER_TYPED_CLASS(NNCKernelRegistry, id, NNCKernel_##kernel);
30 
31 namespace registry {
32 
has_nnc_kernel(const std::string & id)33 inline bool has_nnc_kernel(const std::string& id) {
34   return NNCKernelRegistry()->Has(id);
35 }
36 
get_nnc_kernel(const std::string & id)37 inline std::unique_ptr<NNCKernel> get_nnc_kernel(const std::string& id) {
38   return NNCKernelRegistry()->Create(id);
39 }
40 
41 } // namespace registry
42 
43 } // namespace nnc
44 } // namespace mobile
45 } // namespace jit
46 } // namespace torch
47