1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName 12 13 #include <executorch/backends/vulkan/runtime/vk_api/vk_api.h> 14 15 #include <executorch/backends/vulkan/runtime/utils/VecUtils.h> 16 17 #include <executorch/backends/vulkan/runtime/vk_api/Types.h> 18 19 #include <mutex> 20 #include <unordered_map> 21 22 namespace vkcompute { 23 namespace vkapi { 24 25 class ShaderLayout final { 26 public: 27 using Signature = std::vector<VkDescriptorType>; 28 29 explicit ShaderLayout(VkDevice, const Signature&); 30 31 ShaderLayout(const ShaderLayout&) = delete; 32 ShaderLayout& operator=(const ShaderLayout&) = delete; 33 34 ShaderLayout(ShaderLayout&&) noexcept; 35 ShaderLayout& operator=(ShaderLayout&&) = delete; 36 37 ~ShaderLayout(); 38 39 private: 40 VkDevice device_; 41 VkDescriptorSetLayout handle_; 42 43 public: handle()44 VkDescriptorSetLayout handle() const { 45 return handle_; 46 } 47 48 // We need to define a custom swap function since this class 49 // does not allow for move assignment. The swap function will 50 // be used in the hash map. 51 friend void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept; 52 }; 53 54 struct ShaderInfo final { 55 struct { 56 const uint32_t* bin = nullptr; 57 uint32_t size = 0u; 58 } src_code; 59 60 std::string kernel_name{""}; 61 ShaderLayout::Signature kernel_layout{}; 62 63 // Shader Metadata 64 utils::uvec3 out_tile_size{1u, 1u, 1u}; 65 66 explicit ShaderInfo(); 67 68 explicit ShaderInfo( 69 std::string, 70 const uint32_t*, 71 const uint32_t, 72 std::vector<VkDescriptorType>, 73 const utils::uvec3 tile_size); 74 75 operator bool() const { 76 return src_code.bin != nullptr; 77 }; 78 }; 79 80 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2); 81 82 class ShaderModule final { 83 public: 84 explicit ShaderModule(VkDevice device, const ShaderInfo& source); 85 86 ShaderModule(const ShaderModule&) = delete; 87 ShaderModule& operator=(const ShaderModule&) = delete; 88 89 ShaderModule(ShaderModule&&) noexcept; 90 ShaderModule& operator=(ShaderModule&&) = delete; 91 92 ~ShaderModule(); 93 94 private: 95 VkDevice device_; 96 VkShaderModule handle_; 97 98 public: handle()99 inline VkShaderModule handle() const { 100 return handle_; 101 } 102 103 // We need to define a custom swap function since this class 104 // does not allow for move assignment. The swap function will 105 // be used in the hash map. 106 friend void swap(ShaderModule& lhs, ShaderModule& rhs) noexcept; 107 }; 108 109 class ShaderLayoutCache final { 110 public: 111 explicit ShaderLayoutCache(VkDevice device); 112 113 ShaderLayoutCache(const ShaderLayoutCache&) = delete; 114 ShaderLayoutCache& operator=(const ShaderLayoutCache&) = delete; 115 116 ShaderLayoutCache(ShaderLayoutCache&&) noexcept; 117 ShaderLayoutCache& operator=(ShaderLayoutCache&&) = delete; 118 119 ~ShaderLayoutCache(); 120 121 using Key = ShaderLayout::Signature; 122 using Value = ShaderLayout; 123 124 struct Hasher { operatorHasher125 inline size_t operator()(const ShaderLayout::Signature& signature) const { 126 size_t hashed = 0u; 127 128 for (const VkDescriptorType type : signature) { 129 hashed = 130 utils::hash_combine(hashed, std::hash<VkDescriptorType>()(type)); 131 } 132 133 return hashed; 134 } 135 }; 136 137 private: 138 // Multiple threads could potentially be adding entries into the cache, so use 139 // a mutex to manage access 140 std::mutex cache_mutex_; 141 142 VkDevice device_; 143 std::unordered_map<Key, Value, Hasher> cache_; 144 145 public: 146 VkDescriptorSetLayout retrieve(const Key&); 147 void purge(); 148 }; 149 150 class ShaderCache final { 151 public: 152 explicit ShaderCache(VkDevice device); 153 154 ShaderCache(const ShaderCache&) = delete; 155 ShaderCache& operator=(const ShaderCache&) = delete; 156 157 ShaderCache(ShaderCache&&) noexcept; 158 ShaderCache& operator=(ShaderCache&&) = delete; 159 160 ~ShaderCache(); 161 162 using Key = ShaderInfo; 163 using Value = ShaderModule; 164 165 struct Hasher { operatorHasher166 inline size_t operator()(const ShaderInfo& source) const { 167 size_t seed = 0; 168 seed = utils::hash_combine( 169 seed, std::hash<const uint32_t*>()(source.src_code.bin)); 170 seed = utils::hash_combine( 171 seed, std::hash<uint32_t>()(source.src_code.size)); 172 173 return seed; 174 } 175 }; 176 177 private: 178 // Multiple threads could potentially be adding entries into the cache, so use 179 // a mutex to manage access 180 std::mutex cache_mutex_; 181 182 VkDevice device_; 183 std::unordered_map<Key, Value, Hasher> cache_; 184 185 public: 186 VkShaderModule retrieve(const Key&); 187 void purge(); 188 }; 189 190 } // namespace vkapi 191 } // namespace vkcompute 192 193 inline bool operator==( 194 const VkDescriptorSetLayoutBinding& _1, 195 const VkDescriptorSetLayoutBinding& _2) { 196 return ( 197 _1.binding == _2.binding && _1.descriptorType == _2.descriptorType && 198 _1.descriptorCount == _2.descriptorCount && 199 _1.stageFlags == _2.stageFlags && 200 _1.pImmutableSamplers == _2.pImmutableSamplers); 201 } 202