xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/ShaderRegistry.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4 
5 #ifdef USE_VULKAN_API
6 
7 #include <ATen/native/vulkan/api/Shader.h>
8 
9 #include <string>
10 #include <unordered_map>
11 
// Fetch a shader from the global registry using the literal spelling of the
// identifier passed in: VK_KERNEL(foo) becomes a lookup of the string "foo".
// The argument is stringized as written and is not macro-expanded first.
#define VK_KERNEL(shader_name) VK_KERNEL_FROM_STR(#shader_name)

// Fetch a shader from the global registry using a name computed at runtime.
// The argument is forwarded unchanged to ShaderRegistry::get_shader_info,
// which accepts a const std::string&.
#define VK_KERNEL_FROM_STR(shader_name_str) \
  ::at::native::vulkan::api::shader_registry().get_shader_info(shader_name_str)
17 
18 namespace at {
19 namespace native {
20 namespace vulkan {
21 namespace api {
22 
/*
 * Key that selects which shader an op dispatch entry points at.
 *
 * The underlying type is int8_t, so a key occupies a single byte.
 * NOTE(review): the enumerator names suggest a generic fallback (CATCHALL),
 * GPU-family-specific variants (ADRENO, MALI) and an explicit override
 * (OVERRIDE). The selection policy lives outside this header, so confirm
 * before relying on that reading.
 */
enum class DispatchKey : int8_t {
  CATCHALL = 0,
  ADRENO = 1,
  MALI = 2,
  OVERRIDE = 3
};
29 
30 class ShaderRegistry final {
31   using ShaderListing = std::unordered_map<std::string, ShaderInfo>;
32   using Dispatcher = std::unordered_map<DispatchKey, std::string>;
33   using Registry = std::unordered_map<std::string, Dispatcher>;
34 
35   ShaderListing listings_;
36   Dispatcher dispatcher_;
37   Registry registry_;
38 
39  public:
40   /*
41    * Check if the registry has a shader registered under the given name
42    */
43   bool has_shader(const std::string& shader_name);
44 
45   /*
46    * Check if the registry has a dispatch registered under the given name
47    */
48   bool has_dispatch(const std::string& op_name);
49 
50   /*
51    * Register a ShaderInfo to a given shader name
52    */
53   void register_shader(ShaderInfo&& shader_info);
54 
55   /*
56    * Register a dispatch entry to the given op name
57    */
58   void register_op_dispatch(
59       const std::string& op_name,
60       const DispatchKey key,
61       const std::string& shader_name);
62 
63   /*
64    * Given a shader name, return the ShaderInfo which contains the SPIRV binary
65    */
66   const ShaderInfo& get_shader_info(const std::string& shader_name);
67 };
68 
/*
 * Runs a registration callback from its constructor.
 *
 * Defining a static instance of this type causes `init_fn` to execute during
 * that object's initialization. NOTE(review): this is presumably used by
 * generated code to populate shader_registry(); confirm against the generator.
 */
class ShaderRegisterInit final {
  using InitFn = void();

 public:
  // Calls `init_fn` exactly once, immediately. `init_fn` must be non-null.
  // Marked explicit so a bare function pointer never silently converts into a
  // ShaderRegisterInit.
  explicit ShaderRegisterInit(InitFn* init_fn) {
    init_fn();
  }
};
77 
78 // The global shader registry is retrieved using this function, where it is
79 // declared as a static local variable.
80 ShaderRegistry& shader_registry();
81 
82 } // namespace api
83 } // namespace vulkan
84 } // namespace native
85 } // namespace at
86 
87 #endif /* USE_VULKAN_API */
88