// xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/vk_api/Shader.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12 
13 #include <executorch/backends/vulkan/runtime/vk_api/vk_api.h>
14 
15 #include <executorch/backends/vulkan/runtime/utils/VecUtils.h>
16 
17 #include <executorch/backends/vulkan/runtime/vk_api/Types.h>
18 
19 #include <mutex>
20 #include <unordered_map>
21 
22 namespace vkcompute {
23 namespace vkapi {
24 
25 class ShaderLayout final {
26  public:
27   using Signature = std::vector<VkDescriptorType>;
28 
29   explicit ShaderLayout(VkDevice, const Signature&);
30 
31   ShaderLayout(const ShaderLayout&) = delete;
32   ShaderLayout& operator=(const ShaderLayout&) = delete;
33 
34   ShaderLayout(ShaderLayout&&) noexcept;
35   ShaderLayout& operator=(ShaderLayout&&) = delete;
36 
37   ~ShaderLayout();
38 
39  private:
40   VkDevice device_;
41   VkDescriptorSetLayout handle_;
42 
43  public:
handle()44   VkDescriptorSetLayout handle() const {
45     return handle_;
46   }
47 
48   // We need to define a custom swap function since this class
49   // does not allow for move assignment. The swap function will
50   // be used in the hash map.
51   friend void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept;
52 };
53 
54 struct ShaderInfo final {
55   struct {
56     const uint32_t* bin = nullptr;
57     uint32_t size = 0u;
58   } src_code;
59 
60   std::string kernel_name{""};
61   ShaderLayout::Signature kernel_layout{};
62 
63   // Shader Metadata
64   utils::uvec3 out_tile_size{1u, 1u, 1u};
65 
66   explicit ShaderInfo();
67 
68   explicit ShaderInfo(
69       std::string,
70       const uint32_t*,
71       const uint32_t,
72       std::vector<VkDescriptorType>,
73       const utils::uvec3 tile_size);
74 
75   operator bool() const {
76     return src_code.bin != nullptr;
77   };
78 };
79 
80 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);
81 
82 class ShaderModule final {
83  public:
84   explicit ShaderModule(VkDevice device, const ShaderInfo& source);
85 
86   ShaderModule(const ShaderModule&) = delete;
87   ShaderModule& operator=(const ShaderModule&) = delete;
88 
89   ShaderModule(ShaderModule&&) noexcept;
90   ShaderModule& operator=(ShaderModule&&) = delete;
91 
92   ~ShaderModule();
93 
94  private:
95   VkDevice device_;
96   VkShaderModule handle_;
97 
98  public:
handle()99   inline VkShaderModule handle() const {
100     return handle_;
101   }
102 
103   // We need to define a custom swap function since this class
104   // does not allow for move assignment. The swap function will
105   // be used in the hash map.
106   friend void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept;
107 };
108 
109 class ShaderLayoutCache final {
110  public:
111   explicit ShaderLayoutCache(VkDevice device);
112 
113   ShaderLayoutCache(const ShaderLayoutCache&) = delete;
114   ShaderLayoutCache& operator=(const ShaderLayoutCache&) = delete;
115 
116   ShaderLayoutCache(ShaderLayoutCache&&) noexcept;
117   ShaderLayoutCache& operator=(ShaderLayoutCache&&) = delete;
118 
119   ~ShaderLayoutCache();
120 
121   using Key = ShaderLayout::Signature;
122   using Value = ShaderLayout;
123 
124   struct Hasher {
operatorHasher125     inline size_t operator()(const ShaderLayout::Signature& signature) const {
126       size_t hashed = 0u;
127 
128       for (const VkDescriptorType type : signature) {
129         hashed =
130             utils::hash_combine(hashed, std::hash<VkDescriptorType>()(type));
131       }
132 
133       return hashed;
134     }
135   };
136 
137  private:
138   // Multiple threads could potentially be adding entries into the cache, so use
139   // a mutex to manage access
140   std::mutex cache_mutex_;
141 
142   VkDevice device_;
143   std::unordered_map<Key, Value, Hasher> cache_;
144 
145  public:
146   VkDescriptorSetLayout retrieve(const Key&);
147   void purge();
148 };
149 
150 class ShaderCache final {
151  public:
152   explicit ShaderCache(VkDevice device);
153 
154   ShaderCache(const ShaderCache&) = delete;
155   ShaderCache& operator=(const ShaderCache&) = delete;
156 
157   ShaderCache(ShaderCache&&) noexcept;
158   ShaderCache& operator=(ShaderCache&&) = delete;
159 
160   ~ShaderCache();
161 
162   using Key = ShaderInfo;
163   using Value = ShaderModule;
164 
165   struct Hasher {
operatorHasher166     inline size_t operator()(const ShaderInfo& source) const {
167       size_t seed = 0;
168       seed = utils::hash_combine(
169           seed, std::hash<const uint32_t*>()(source.src_code.bin));
170       seed = utils::hash_combine(
171           seed, std::hash<uint32_t>()(source.src_code.size));
172 
173       return seed;
174     }
175   };
176 
177  private:
178   // Multiple threads could potentially be adding entries into the cache, so use
179   // a mutex to manage access
180   std::mutex cache_mutex_;
181 
182   VkDevice device_;
183   std::unordered_map<Key, Value, Hasher> cache_;
184 
185  public:
186   VkShaderModule retrieve(const Key&);
187   void purge();
188 };
189 
190 } // namespace vkapi
191 } // namespace vkcompute
192 
193 inline bool operator==(
194     const VkDescriptorSetLayoutBinding& _1,
195     const VkDescriptorSetLayoutBinding& _2) {
196   return (
197       _1.binding == _2.binding && _1.descriptorType == _2.descriptorType &&
198       _1.descriptorCount == _2.descriptorCount &&
199       _1.stageFlags == _2.stageFlags &&
200       _1.pImmutableSamplers == _2.pImmutableSamplers);
201 }
202