1 #pragma once 2 3 #include <c10/util/Exception.h> 4 #include <c10/util/Registry.h> 5 6 namespace torch { 7 namespace jit { 8 namespace mobile { 9 namespace nnc { 10 11 using nnc_kernel_function_type = int(void**); 12 13 struct TORCH_API NNCKernel { 14 virtual ~NNCKernel() = default; 15 virtual int execute(void** /* args */) = 0; 16 }; 17 18 TORCH_DECLARE_REGISTRY(NNCKernelRegistry, NNCKernel); 19 20 #define REGISTER_NNC_KERNEL(id, kernel, ...) \ 21 extern "C" { \ 22 nnc_kernel_function_type kernel; \ 23 } \ 24 struct NNCKernel_##kernel : public NNCKernel { \ 25 int execute(void** args) override { \ 26 return kernel(args); \ 27 } \ 28 }; \ 29 C10_REGISTER_TYPED_CLASS(NNCKernelRegistry, id, NNCKernel_##kernel); 30 31 namespace registry { 32 has_nnc_kernel(const std::string & id)33inline bool has_nnc_kernel(const std::string& id) { 34 return NNCKernelRegistry()->Has(id); 35 } 36 get_nnc_kernel(const std::string & id)37inline std::unique_ptr<NNCKernel> get_nnc_kernel(const std::string& id) { 38 return NNCKernelRegistry()->Create(id); 39 } 40 41 } // namespace registry 42 43 } // namespace nnc 44 } // namespace mobile 45 } // namespace jit 46 } // namespace torch 47