xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/api/ShaderRegistry.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <executorch/backends/vulkan/runtime/api/ShaderRegistry.h>

#include <string>
#include <utility>
10 
11 namespace vkcompute {
12 namespace api {
13 
has_shader(const std::string & shader_name)14 bool ShaderRegistry::has_shader(const std::string& shader_name) {
15   const ShaderListing::const_iterator it = listings_.find(shader_name);
16   return it != listings_.end();
17 }
18 
has_dispatch(const std::string & op_name)19 bool ShaderRegistry::has_dispatch(const std::string& op_name) {
20   const Registry::const_iterator it = registry_.find(op_name);
21   return it != registry_.end();
22 }
23 
register_shader(vkapi::ShaderInfo && shader_info)24 void ShaderRegistry::register_shader(vkapi::ShaderInfo&& shader_info) {
25   if (has_shader(shader_info.kernel_name)) {
26     VK_THROW(
27         "Shader with name ", shader_info.kernel_name, "already registered");
28   }
29   listings_.emplace(shader_info.kernel_name, shader_info);
30 }
31 
register_op_dispatch(const std::string & op_name,const DispatchKey key,const std::string & shader_name)32 void ShaderRegistry::register_op_dispatch(
33     const std::string& op_name,
34     const DispatchKey key,
35     const std::string& shader_name) {
36   if (!has_dispatch(op_name)) {
37     registry_.emplace(op_name, Dispatcher());
38   }
39   const Dispatcher::const_iterator it = registry_[op_name].find(key);
40   if (it != registry_[op_name].end()) {
41     registry_[op_name][key] = shader_name;
42   } else {
43     registry_[op_name].emplace(key, shader_name);
44   }
45 }
46 
get_shader_info(const std::string & shader_name)47 const vkapi::ShaderInfo& ShaderRegistry::get_shader_info(
48     const std::string& shader_name) {
49   const ShaderListing::const_iterator it = listings_.find(shader_name);
50 
51   VK_CHECK_COND(
52       it != listings_.end(),
53       "Could not find ShaderInfo with name ",
54       shader_name);
55 
56   return it->second;
57 }
58 
shader_registry()59 ShaderRegistry& shader_registry() {
60   static ShaderRegistry registry;
61   return registry;
62 }
63 
64 } // namespace api
65 } // namespace vkcompute
66