#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Shader.h>

#include <cstdint>
#include <string>
#include <unordered_map>

/*
 * Look up a ShaderInfo in the global shader registry. The macro argument is
 * stringified, so it must be written as a bare token: VK_KERNEL(foo) calls
 * get_shader_info("foo").
 */
#define VK_KERNEL(shader_name) \
  ::at::native::vulkan::api::shader_registry().get_shader_info(#shader_name)

/*
 * Same lookup as VK_KERNEL, but the argument is forwarded unchanged, so it
 * must already be an expression convertible to const std::string&.
 */
#define VK_KERNEL_FROM_STR(shader_name_str) \
  ::at::native::vulkan::api::shader_registry().get_shader_info(shader_name_str)

namespace at {
namespace native {
namespace vulkan {
namespace api {

/*
 * Key under which a shader name is stored in a dispatch table (see
 * ShaderRegistry::register_op_dispatch).
 * NOTE(review): ADRENO / MALI presumably select GPU-vendor-specific shader
 * variants and CATCHALL the default; how keys are prioritized is not visible
 * in this header - confirm against the implementation.
 */
enum class DispatchKey : int8_t {
  CATCHALL,
  ADRENO,
  MALI,
  OVERRIDE,
};

class ShaderRegistry final {
  // Maps a string (presumably the shader name) to its ShaderInfo.
  using ShaderListing = std::unordered_map<std::string, ShaderInfo>;
  // Maps a DispatchKey to a string (presumably a shader name).
  using Dispatcher = std::unordered_map<DispatchKey, std::string>;
  // Maps a string (presumably the op name) to that op's Dispatcher.
  using Registry = std::unordered_map<std::string, Dispatcher>;

  ShaderListing listings_;
  Dispatcher dispatcher_;
  Registry registry_;

 public:
  /*
   * Check if the registry has a shader registered under the given name
   */
  bool has_shader(const std::string& shader_name);

  /*
   * Check if the registry has a dispatch registered under the given name
   */
  bool has_dispatch(const std::string& op_name);

  /*
   * Register a ShaderInfo to a given shader name
   */
  void register_shader(ShaderInfo&& shader_info);

  /*
   * Register a dispatch entry to the given op name
   */
  void register_op_dispatch(
      const std::string& op_name,
      const DispatchKey key,
      const std::string& shader_name);

  /*
   * Given a shader name, return the ShaderInfo which contains the SPIRV binary
   */
  const ShaderInfo& get_shader_info(const std::string& shader_name);
};

/*
 * Helper whose constructor immediately invokes the supplied function.
 * NOTE(review): presumably instantiated as a static object so that shader
 * registration runs during static initialization - confirm against callers.
 */
class ShaderRegisterInit final {
  using InitFn = void();

 public:
  /*
   * Calls init_fn exactly once. init_fn is invoked unconditionally, so it
   * must not be null.
   * NOTE(review): left non-explicit so that existing copy-initialization
   * call sites, if any, keep compiling; consider marking it explicit once
   * callers are audited.
   */
  ShaderRegisterInit(InitFn* init_fn) {
    init_fn();
  }
};

// The global shader registry is retrieved using this function, where it is
// declared as a static local variable.
ShaderRegistry& shader_registry();

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */