1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName 12 13 #include <executorch/backends/vulkan/runtime/vk_api/vk_api.h> 14 15 #include <executorch/backends/vulkan/runtime/vk_api/Shader.h> 16 17 #include <executorch/backends/vulkan/runtime/vk_api/memory/Buffer.h> 18 #include <executorch/backends/vulkan/runtime/vk_api/memory/Image.h> 19 20 #include <mutex> 21 #include <unordered_map> 22 23 #define SV(x) ::vkcompute::vkapi::SpecVar(x) 24 25 namespace vkcompute { 26 namespace vkapi { 27 28 struct SpecVar final { 29 enum class Type : uint8_t { 30 FLOAT, 31 INT, 32 UINT, 33 BOOL, 34 }; 35 36 union Value { 37 int32_t as_int32; 38 uint32_t as_uint32; 39 float as_float; 40 bool as_bool; 41 }; 42 43 Value value; 44 Type type; 45 46 SpecVar(); 47 SpecVar(const float val); 48 SpecVar(const int32_t val); 49 SpecVar(const uint32_t val); 50 SpecVar(const bool val); 51 52 uint32_t val_size() const; 53 uint32_t val_offset() const; 54 }; 55 56 bool operator==(const SpecVar& lhs, const SpecVar& rhs); 57 58 bool operator!=(const SpecVar& lhs, const SpecVar& rhs); 59 60 class SpecVarList final { 61 std::vector<SpecVar> vars; 62 63 public: 64 SpecVarList(); 65 SpecVarList(std::initializer_list<SpecVar> init_list); 66 at(const size_t index)67 inline const SpecVar& at(const size_t index) const { 68 return vars.at(index); 69 } 70 data()71 inline const SpecVar* data() const { 72 return vars.data(); 73 } 74 size()75 inline uint32_t size() const { 76 return utils::safe_downcast<uint32_t>(vars.size()); 77 } 78 data_nbytes()79 inline uint32_t data_nbytes() const { 80 return vars.size() * sizeof(SpecVar); 81 } 82 83 void append(const SpecVarList& other); 84 85 std::vector<VkSpecializationMapEntry> generate_map_entries() const; 86 87 friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs); 88 }; 89 90 bool operator==(const SpecVarList& lhs, const SpecVarList& rhs); 91 92 struct PipelineBarrier final { 93 struct Stages final { 94 VkPipelineStageFlags src; 95 VkPipelineStageFlags dst; 96 } stage; 97 98 std::vector<BufferMemoryBarrier> buffers; 99 std::vector<ImageMemoryBarrier> images; 100 std::vector<VkBufferMemoryBarrier> buffer_barrier_handles; 101 std::vector<VkImageMemoryBarrier> image_barrier_handles; 102 103 inline operator bool() const { 104 return (0u != stage.src) || (0u != stage.dst) || !buffers.empty() || 105 !images.empty(); 106 } 107 }; 108 109 using PipelineStageFlags = uint8_t; 110 111 enum PipelineStage : PipelineStageFlags { 112 NO_STAGE = 0u << 0u, 113 COMPUTE = 1u << 0u, 114 HOST = 1u << 1u, 115 TRANSFER = 1u << 2u, 116 }; 117 118 VkAccessFlags vk_access(const PipelineStageFlags, const MemoryAccessFlags); 119 VkPipelineStageFlags vk_stage(const PipelineStageFlags); 120 VkImageLayout vk_layout(const PipelineStageFlags, const MemoryAccessFlags); 121 122 class PipelineLayout final { 123 public: 124 explicit PipelineLayout(VkDevice, VkDescriptorSetLayout); 125 126 PipelineLayout(const PipelineLayout&) = delete; 127 PipelineLayout& operator=(const PipelineLayout&) = delete; 128 129 PipelineLayout(PipelineLayout&&) noexcept; 130 PipelineLayout& operator=(PipelineLayout&&) = delete; 131 132 ~PipelineLayout(); 133 134 private: 135 VkDevice device_; 136 VkPipelineLayout handle_; 137 138 public: handle()139 VkPipelineLayout handle() const { 140 return handle_; 141 } 142 143 // We need to define a custom swap function since this class 144 // does not allow for move assignment. The swap function will 145 // be used in the hash map. 146 friend void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept; 147 }; 148 149 class ComputePipeline final { 150 public: 151 struct Descriptor final { 152 VkPipelineLayout pipeline_layout; 153 VkShaderModule shader_module; 154 SpecVarList specialization_constants; 155 }; 156 157 explicit ComputePipeline( 158 VkDevice device, 159 const Descriptor& descriptor, 160 VkPipelineCache pipeline_cache); 161 162 ComputePipeline(const ComputePipeline&) = delete; 163 ComputePipeline& operator=(const ComputePipeline&) = delete; 164 165 ComputePipeline(ComputePipeline&&) noexcept; 166 ComputePipeline& operator=(ComputePipeline&&) = delete; 167 168 ~ComputePipeline(); 169 170 private: 171 VkDevice device_; 172 VkPipeline handle_; 173 174 public: handle()175 inline VkPipeline handle() const { 176 return handle_; 177 } 178 179 // We need to define a custom swap function since this class 180 // does not allow for move assignment. The swap function will 181 // be used in the hash map. 182 friend void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept; 183 }; 184 185 class PipelineLayoutCache final { 186 public: 187 explicit PipelineLayoutCache(VkDevice device); 188 189 PipelineLayoutCache(const PipelineLayoutCache&) = delete; 190 PipelineLayoutCache& operator=(const PipelineLayoutCache&) = delete; 191 192 PipelineLayoutCache(PipelineLayoutCache&&) noexcept; 193 PipelineLayoutCache& operator=(PipelineLayoutCache&&) = delete; 194 195 ~PipelineLayoutCache(); 196 197 using Key = VkDescriptorSetLayout; 198 using Value = PipelineLayout; 199 200 struct Hasher { operatorHasher201 inline size_t operator()(VkDescriptorSetLayout descriptor_layout) const { 202 return std::hash<VkDescriptorSetLayout>()(descriptor_layout); 203 } 204 }; 205 206 private: 207 // Multiple threads could potentially be adding entries into the cache, so use 208 // a mutex to manage access 209 std::mutex cache_mutex_; 210 211 VkDevice device_; 212 std::unordered_map<Key, Value, Hasher> cache_; 213 214 public: 215 VkPipelineLayout retrieve(const Key&); 216 void purge(); 217 }; 218 219 class ComputePipelineCache final { 220 public: 221 explicit ComputePipelineCache( 222 VkDevice device, 223 const std::string& cache_data_path); 224 225 ComputePipelineCache(const ComputePipelineCache&) = delete; 226 ComputePipelineCache& operator=(const ComputePipelineCache&) = delete; 227 228 ComputePipelineCache(ComputePipelineCache&&) noexcept; 229 ComputePipelineCache& operator=(ComputePipelineCache&&) = delete; 230 231 ~ComputePipelineCache(); 232 233 using Key = ComputePipeline::Descriptor; 234 using Value = ComputePipeline; 235 236 struct Hasher { operatorHasher237 inline size_t operator()( 238 const ComputePipeline::Descriptor& descriptor) const { 239 size_t seed = 0; 240 seed = utils::hash_combine( 241 seed, std::hash<VkPipelineLayout>()(descriptor.pipeline_layout)); 242 seed = utils::hash_combine( 243 seed, std::hash<VkShaderModule>()(descriptor.shader_module)); 244 245 const SpecVarList& spec_vars = descriptor.specialization_constants; 246 seed = utils::hash_combine(seed, std::hash<uint32_t>()(spec_vars.size())); 247 248 for (int i = 0; i < spec_vars.size(); ++i) { 249 const SpecVar& spec_var = spec_vars.at(i); 250 size_t new_seed = 0; 251 switch (spec_var.type) { 252 case SpecVar::Type::FLOAT: 253 new_seed = std::hash<float>()(spec_var.value.as_float); 254 break; 255 case SpecVar::Type::INT: 256 new_seed = std::hash<int32_t>()(spec_var.value.as_int32); 257 break; 258 case SpecVar::Type::UINT: 259 new_seed = std::hash<uint32_t>()(spec_var.value.as_uint32); 260 break; 261 case SpecVar::Type::BOOL: 262 new_seed = std::hash<bool>()(spec_var.value.as_bool); 263 break; 264 } 265 seed = utils::hash_combine(seed, new_seed); 266 } 267 268 return seed; 269 } 270 }; 271 272 void save_cache(); 273 274 private: 275 std::vector<char> load_cache(); 276 277 // Multiple threads could potentially be adding entries into the cache, so use 278 // a mutex to manage access 279 std::mutex cache_mutex_; 280 281 VkDevice device_; 282 VkPipelineCache pipeline_cache_; 283 std::unordered_map<Key, Value, Hasher> cache_; 284 const std::string cache_data_path_; 285 286 public: 287 VkPipeline retrieve(const Key&); 288 void purge(); 289 }; 290 291 // 292 // Impl 293 // 294 295 } // namespace vkapi 296 } // namespace vkcompute 297