xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/vk_api/Pipeline.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12 
13 #include <executorch/backends/vulkan/runtime/vk_api/vk_api.h>
14 
15 #include <executorch/backends/vulkan/runtime/vk_api/Shader.h>
16 
17 #include <executorch/backends/vulkan/runtime/vk_api/memory/Buffer.h>
18 #include <executorch/backends/vulkan/runtime/vk_api/memory/Image.h>
19 
20 #include <mutex>
21 #include <unordered_map>
22 
// Shorthand that constructs a ::vkcompute::vkapi::SpecVar from `x`. The type
// name is fully qualified so the macro expands correctly from any namespace.
#define SV(x) ::vkcompute::vkapi::SpecVar(x)
24 
25 namespace vkcompute {
26 namespace vkapi {
27 
// A single tagged scalar value. The payload lives in `value` and `type`
// records which member of the union is the meaningful one. SpecVar objects
// are collected in SpecVarList.
struct SpecVar final {
  // Discriminator for the payload union below.
  enum class Type : uint8_t {
    FLOAT,
    INT,
    UINT,
    BOOL,
  };

  // Untagged storage shared by all supported scalar kinds.
  union Value {
    int32_t as_int32;
    uint32_t as_uint32;
    float as_float;
    bool as_bool;
  };

  // Payload first, tag second. Keep this declaration order: it determines the
  // object layout.
  Value value;
  Type type;

  SpecVar();

  // Converting constructors, one per supported scalar kind. They are not
  // marked explicit, so a float / int32_t / uint32_t / bool converts
  // implicitly to a SpecVar.
  SpecVar(float val);
  SpecVar(int32_t val);
  SpecVar(uint32_t val);
  SpecVar(bool val);

  // Defined out of line.
  // NOTE(review): presumably the byte size and byte offset of the active
  // payload member inside SpecVar - confirm against the implementation.
  uint32_t val_size() const;
  uint32_t val_offset() const;
};

// Equality / inequality for SpecVar; both are defined out of line.
bool operator==(const SpecVar& lhs, const SpecVar& rhs);

bool operator!=(const SpecVar& lhs, const SpecVar& rhs);
59 
60 class SpecVarList final {
61   std::vector<SpecVar> vars;
62 
63  public:
64   SpecVarList();
65   SpecVarList(std::initializer_list<SpecVar> init_list);
66 
at(const size_t index)67   inline const SpecVar& at(const size_t index) const {
68     return vars.at(index);
69   }
70 
data()71   inline const SpecVar* data() const {
72     return vars.data();
73   }
74 
size()75   inline uint32_t size() const {
76     return utils::safe_downcast<uint32_t>(vars.size());
77   }
78 
data_nbytes()79   inline uint32_t data_nbytes() const {
80     return vars.size() * sizeof(SpecVar);
81   }
82 
83   void append(const SpecVarList& other);
84 
85   std::vector<VkSpecializationMapEntry> generate_map_entries() const;
86 
87   friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
88 };
89 
90 bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
91 
92 struct PipelineBarrier final {
93   struct Stages final {
94     VkPipelineStageFlags src;
95     VkPipelineStageFlags dst;
96   } stage;
97 
98   std::vector<BufferMemoryBarrier> buffers;
99   std::vector<ImageMemoryBarrier> images;
100   std::vector<VkBufferMemoryBarrier> buffer_barrier_handles;
101   std::vector<VkImageMemoryBarrier> image_barrier_handles;
102 
103   inline operator bool() const {
104     return (0u != stage.src) || (0u != stage.dst) || !buffers.empty() ||
105         !images.empty();
106   }
107 };
108 
// Bitmask type holding a combination of PipelineStage bits.
using PipelineStageFlags = uint8_t;

// Individual pipeline stage bits; each non-zero enumerator occupies its own
// bit so values can be OR-ed together into a PipelineStageFlags mask.
enum PipelineStage : PipelineStageFlags {
  NO_STAGE = 0x0,
  COMPUTE = 0x1,
  HOST = 0x2,
  TRANSFER = 0x4,
};
117 
118 VkAccessFlags vk_access(const PipelineStageFlags, const MemoryAccessFlags);
119 VkPipelineStageFlags vk_stage(const PipelineStageFlags);
120 VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags);
121 
122 class PipelineLayout final {
123  public:
124   explicit PipelineLayout(VkDevice, VkDescriptorSetLayout);
125 
126   PipelineLayout(const PipelineLayout&) = delete;
127   PipelineLayout& operator=(const PipelineLayout&) = delete;
128 
129   PipelineLayout(PipelineLayout&&) noexcept;
130   PipelineLayout& operator=(PipelineLayout&&) = delete;
131 
132   ~PipelineLayout();
133 
134  private:
135   VkDevice device_;
136   VkPipelineLayout handle_;
137 
138  public:
handle()139   VkPipelineLayout handle() const {
140     return handle_;
141   }
142 
143   // We need to define a custom swap function since this class
144   // does not allow for move assignment. The swap function will
145   // be used in the hash map.
146   friend void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept;
147 };
148 
149 class ComputePipeline final {
150  public:
151   struct Descriptor final {
152     VkPipelineLayout pipeline_layout;
153     VkShaderModule shader_module;
154     SpecVarList specialization_constants;
155   };
156 
157   explicit ComputePipeline(
158       VkDevice device,
159       const Descriptor& descriptor,
160       VkPipelineCache pipeline_cache);
161 
162   ComputePipeline(const ComputePipeline&) = delete;
163   ComputePipeline& operator=(const ComputePipeline&) = delete;
164 
165   ComputePipeline(ComputePipeline&&) noexcept;
166   ComputePipeline& operator=(ComputePipeline&&) = delete;
167 
168   ~ComputePipeline();
169 
170  private:
171   VkDevice device_;
172   VkPipeline handle_;
173 
174  public:
handle()175   inline VkPipeline handle() const {
176     return handle_;
177   }
178 
179   // We need to define a custom swap function since this class
180   // does not allow for move assignment. The swap function will
181   // be used in the hash map.
182   friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept;
183 };
184 
185 class PipelineLayoutCache final {
186  public:
187   explicit PipelineLayoutCache(VkDevice device);
188 
189   PipelineLayoutCache(const PipelineLayoutCache&) = delete;
190   PipelineLayoutCache& operator=(const PipelineLayoutCache&) = delete;
191 
192   PipelineLayoutCache(PipelineLayoutCache&&) noexcept;
193   PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete;
194 
195   ~PipelineLayoutCache();
196 
197   using Key = VkDescriptorSetLayout;
198   using Value = PipelineLayout;
199 
200   struct Hasher {
operatorHasher201     inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const {
202       return std::hash<VkDescriptorSetLayout>()(descriptor_layout);
203     }
204   };
205 
206  private:
207   // Multiple threads could potentially be adding entries into the cache, so use
208   // a mutex to manage access
209   std::mutex cache_mutex_;
210 
211   VkDevice device_;
212   std::unordered_map<Key, Value, Hasher> cache_;
213 
214  public:
215   VkPipelineLayout retrieve(const Key&);
216   void purge();
217 };
218 
219 class ComputePipelineCache final {
220  public:
221   explicit ComputePipelineCache(
222       VkDevice device,
223       const std::string& cache_data_path);
224 
225   ComputePipelineCache(const ComputePipelineCache&) = delete;
226   ComputePipelineCache& operator=(const ComputePipelineCache&) = delete;
227 
228   ComputePipelineCache(ComputePipelineCache&&) noexcept;
229   ComputePipelineCache& operator=(ComputePipelineCache&&) = delete;
230 
231   ~ComputePipelineCache();
232 
233   using Key = ComputePipeline::Descriptor;
234   using Value = ComputePipeline;
235 
236   struct Hasher {
operatorHasher237     inline size_t operator()(
238         const ComputePipeline::Descriptor& descriptor) const {
239       size_t seed = 0;
240       seed = utils::hash_combine(
241           seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout));
242       seed = utils::hash_combine(
243           seed, std::hash<VkShaderModule>()(descriptor.shader_module));
244 
245       const SpecVarList& spec_vars = descriptor.specialization_constants;
246       seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size()));
247 
248       for (int i = 0; i < spec_vars.size(); ++i) {
249         const SpecVar& spec_var = spec_vars.at(i);
250         size_t new_seed = 0;
251         switch (spec_var.type) {
252           case SpecVar::Type::FLOAT:
253             new_seed = std::hash<float>()(spec_var.value.as_float);
254             break;
255           case SpecVar::Type::INT:
256             new_seed = std::hash<int32_t>()(spec_var.value.as_int32);
257             break;
258           case SpecVar::Type::UINT:
259             new_seed = std::hash<uint32_t>()(spec_var.value.as_uint32);
260             break;
261           case SpecVar::Type::BOOL:
262             new_seed = std::hash<bool>()(spec_var.value.as_bool);
263             break;
264         }
265         seed = utils::hash_combine(seed, new_seed);
266       }
267 
268       return seed;
269     }
270   };
271 
272   void save_cache();
273 
274  private:
275   std::vector<char> load_cache();
276 
277   // Multiple threads could potentially be adding entries into the cache, so use
278   // a mutex to manage access
279   std::mutex cache_mutex_;
280 
281   VkDevice device_;
282   VkPipelineCache pipeline_cache_;
283   std::unordered_map<Key, Value, Hasher> cache_;
284   const std::string cache_data_path_;
285 
286  public:
287   VkPipeline retrieve(const Key&);
288   void purge();
289 };
290 
291 //
292 // Impl
293 //
294 
295 } // namespace vkapi
296 } // namespace vkcompute
297