xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/vk_api/Pipeline.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/vk_api/Pipeline.h>
10 
11 #include <fstream>
12 
13 namespace vkcompute {
14 namespace vkapi {
15 
16 //
17 // Utility Functions
18 //
19 
vk_access(const PipelineStageFlags stage,const MemoryAccessFlags access)20 VkAccessFlags vk_access(
21     const PipelineStageFlags stage,
22     const MemoryAccessFlags access) {
23   VkAccessFlags vk_access = 0u;
24 
25   if (access & MemoryAccessType::READ) {
26     if (stage & PipelineStage::COMPUTE) {
27       vk_access |= VK_ACCESS_SHADER_READ_BIT;
28     }
29 
30     if (stage & PipelineStage::HOST) {
31       vk_access |= VK_ACCESS_HOST_READ_BIT;
32     }
33 
34     if (stage & PipelineStage::TRANSFER) {
35       vk_access |= VK_ACCESS_TRANSFER_READ_BIT;
36     }
37   }
38 
39   if (access & MemoryAccessType::WRITE) {
40     if (stage & PipelineStage::COMPUTE) {
41       vk_access |= VK_ACCESS_SHADER_WRITE_BIT;
42     }
43 
44     if (stage & PipelineStage::HOST) {
45       vk_access |= VK_ACCESS_HOST_WRITE_BIT;
46     }
47 
48     if (stage & PipelineStage::TRANSFER) {
49       vk_access |= VK_ACCESS_TRANSFER_WRITE_BIT;
50     }
51   }
52 
53   return vk_access;
54 }
55 
vk_stage(const PipelineStageFlags stage)56 VkPipelineStageFlags vk_stage(const PipelineStageFlags stage) {
57   VkPipelineStageFlags vk_stage = 0u;
58 
59   if (stage & PipelineStage::COMPUTE) {
60     vk_stage |= VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
61   }
62 
63   if (stage & PipelineStage::HOST) {
64     vk_stage |= VK_PIPELINE_STAGE_HOST_BIT;
65   }
66 
67   if (stage & PipelineStage::TRANSFER) {
68     vk_stage |= VK_PIPELINE_STAGE_TRANSFER_BIT;
69   }
70 
71   return vk_stage;
72 }
73 
vk_layout(const PipelineStageFlags stage,const MemoryAccessFlags access)74 VkImageLayout vk_layout(
75     const PipelineStageFlags stage,
76     const MemoryAccessFlags access) {
77   switch (stage) {
78     case PipelineStage::COMPUTE:
79       switch (access) {
80         case MemoryAccessType::READ:
81           return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
82         default:
83           return VK_IMAGE_LAYOUT_GENERAL;
84       }
85       break;
86     case PipelineStage::TRANSFER:
87       switch (access) {
88         case MemoryAccessType::READ:
89           return VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL;
90         case MemoryAccessType::WRITE:
91           return VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
92         default:
93           VK_THROW("Invalid memory access type for transfer stage!");
94       }
95       break;
96     default:
97       VK_THROW("Cannot determine appropriate image layout");
98   }
99 
100   return VK_IMAGE_LAYOUT_UNDEFINED;
101 }
102 
103 //
104 // SpecVar
105 //
106 
// Default spec constant: a 32-bit integer with value zero.
SpecVar::SpecVar() : type(SpecVar::Type::INT) {
  value.as_int32 = 0;
}
110 
// Constructs a FLOAT-typed spec constant holding `val`.
SpecVar::SpecVar(const float val) : type(SpecVar::Type::FLOAT) {
  value.as_float = val;
}
114 
// Constructs an INT-typed spec constant holding `val`.
SpecVar::SpecVar(const int32_t val) : type(SpecVar::Type::INT) {
  value.as_int32 = val;
}
118 
// Constructs a UINT-typed spec constant holding `val`.
SpecVar::SpecVar(const uint32_t val) : type(SpecVar::Type::UINT) {
  value.as_uint32 = val;
}
122 
// Constructs a BOOL-typed spec constant holding `val`.
SpecVar::SpecVar(const bool val) : type(SpecVar::Type::BOOL) {
  value.as_bool = val;
}
126 
val_size() const127 uint32_t SpecVar::val_size() const {
128   switch (type) {
129     case SpecVar::Type::FLOAT:
130       return sizeof(float);
131     case SpecVar::Type::INT:
132       return sizeof(int32_t);
133     case SpecVar::Type::UINT:
134       return sizeof(uint32_t);
135     case SpecVar::Type::BOOL:
136       return sizeof(bool);
137   }
138   return 4;
139 }
140 
// Byte offset of the `value` union within a SpecVar struct; combined with a
// per-element base offset so a VkSpecializationMapEntry can point at the
// payload inside a packed array of SpecVar.
uint32_t SpecVar::val_offset() const {
  return utils::safe_downcast<uint32_t>(offsetof(SpecVar, value));
}
144 
operator ==(const SpecVar & lhs,const SpecVar & rhs)145 bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
146   if (lhs.type != rhs.type) {
147     return false;
148   }
149   switch (lhs.type) {
150     case SpecVar::Type::FLOAT:
151       return lhs.value.as_float == rhs.value.as_float;
152     case SpecVar::Type::INT:
153       return lhs.value.as_int32 == rhs.value.as_int32;
154     case SpecVar::Type::UINT:
155       return lhs.value.as_uint32 == rhs.value.as_uint32;
156     case SpecVar::Type::BOOL:
157       return lhs.value.as_bool == rhs.value.as_bool;
158   }
159   return false;
160 }
161 
// Defined in terms of operator== so the two stay consistent.
bool operator!=(const SpecVar& lhs, const SpecVar& rhs) {
  return !(lhs == rhs);
}
165 
SpecVarList()166 SpecVarList::SpecVarList() {}
167 
SpecVarList(std::initializer_list<SpecVar> init_list)168 SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list) {
169   vars.resize(init_list.size());
170   std::copy(init_list.begin(), init_list.end(), vars.begin());
171 }
172 
// Appends all of `other`'s spec constants after this list's existing entries.
void SpecVarList::append(const SpecVarList& other) {
  vars.insert(vars.end(), other.vars.begin(), other.vars.end());
}
176 
generate_map_entries() const177 std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
178     const {
179   std::vector<VkSpecializationMapEntry> map_entries;
180   map_entries.resize(vars.size());
181   uint32_t cur_offset = 0u;
182   for (uint32_t i = 0; i < vars.size(); ++i) {
183     map_entries.at(i) = {
184         i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()};
185     cur_offset += sizeof(SpecVar);
186   }
187   return map_entries;
188 }
189 
operator ==(const SpecVarList & lhs,const SpecVarList & rhs)190 bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
191   if (lhs.size() != rhs.size()) {
192     return false;
193   }
194   for (uint32_t i = 0; i < lhs.size(); ++i) {
195     if (lhs.vars.at(i) != rhs.vars.at(i)) {
196       return false;
197     }
198   }
199   return true;
200 }
201 
202 //
203 // PipelineLayout
204 //
205 
PipelineLayout(VkDevice device,VkDescriptorSetLayout descriptor_layout)206 PipelineLayout::PipelineLayout(
207     VkDevice device,
208     VkDescriptorSetLayout descriptor_layout)
209     : device_(device), handle_{VK_NULL_HANDLE} {
210   // TODO: Enable push constants
211   const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
212       VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
213       nullptr, // pNext
214       0u, // flags
215       1u, // setLayoutCount
216       &descriptor_layout, // pSetLayouts
217       0u, // pushConstantRangeCount
218       nullptr, // pPushConstantRanges
219   };
220 
221   VK_CHECK(vkCreatePipelineLayout(
222       device_, &pipeline_layout_create_info, nullptr, &handle_));
223 }
224 
// Move constructor: steals the Vulkan handle and nulls it out in the source
// so only one object ever destroys it.
PipelineLayout::PipelineLayout(PipelineLayout&& other) noexcept
    : device_(other.device_), handle_(other.handle_) {
  other.handle_ = VK_NULL_HANDLE;
}
229 
~PipelineLayout()230 PipelineLayout::~PipelineLayout() {
231   if (handle_ == VK_NULL_HANDLE) {
232     return;
233   }
234   vkDestroyPipelineLayout(device_, handle_, nullptr);
235   handle_ = VK_NULL_HANDLE;
236 }
237 
swap(PipelineLayout & lhs,PipelineLayout & rhs)238 void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
239   VkDevice tmp_device = lhs.device_;
240   VkPipelineLayout tmp_handle = lhs.handle_;
241 
242   lhs.device_ = rhs.device_;
243   lhs.handle_ = rhs.handle_;
244 
245   rhs.device_ = tmp_device;
246   rhs.handle_ = tmp_handle;
247 }
248 
249 //
250 // ComputePipeline
251 //
252 
ComputePipeline(VkDevice device,const ComputePipeline::Descriptor & descriptor,VkPipelineCache pipeline_cache)253 ComputePipeline::ComputePipeline(
254     VkDevice device,
255     const ComputePipeline::Descriptor& descriptor,
256     VkPipelineCache pipeline_cache)
257     : device_(device), handle_{VK_NULL_HANDLE} {
258   std::vector<VkSpecializationMapEntry> map_entries =
259       descriptor.specialization_constants.generate_map_entries();
260 
261   const VkSpecializationInfo specialization_info{
262       descriptor.specialization_constants.size(), // mapEntryCount
263       map_entries.data(), // pMapEntries
264       descriptor.specialization_constants.data_nbytes(), // dataSize
265       descriptor.specialization_constants.data(), // pData
266   };
267 
268   const VkPipelineShaderStageCreateInfo shader_stage_create_info{
269       VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
270       nullptr, // pNext
271       0u, // flags
272       VK_SHADER_STAGE_COMPUTE_BIT, // stage
273       descriptor.shader_module, // module
274       "main", // pName
275       &specialization_info, // pSpecializationInfo
276   };
277 
278   const VkComputePipelineCreateInfo compute_pipeline_create_info{
279       VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
280       nullptr, // pNext
281       0u, // flags
282       shader_stage_create_info, // stage
283       descriptor.pipeline_layout, // layout
284       VK_NULL_HANDLE, // basePipelineHandle
285       0u, // basePipelineIndex
286   };
287 
288   VK_CHECK(vkCreateComputePipelines(
289       device_,
290       pipeline_cache,
291       1u,
292       &compute_pipeline_create_info,
293       nullptr,
294       &handle_));
295 }
296 
// Move constructor: transfers the pipeline handle and nulls the source so
// only one object ever destroys it.
ComputePipeline::ComputePipeline(ComputePipeline&& other) noexcept
    : device_(other.device_), handle_(other.handle_) {
  other.handle_ = VK_NULL_HANDLE;
}
301 
~ComputePipeline()302 ComputePipeline::~ComputePipeline() {
303   if (handle_ == VK_NULL_HANDLE) {
304     return;
305   }
306   vkDestroyPipeline(device_, handle_, nullptr);
307   handle_ = VK_NULL_HANDLE;
308 }
309 
swap(ComputePipeline & lhs,ComputePipeline & rhs)310 void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept {
311   VkDevice tmp_device = lhs.device_;
312   VkPipeline tmp_handle = lhs.handle_;
313 
314   lhs.device_ = rhs.device_;
315   lhs.handle_ = rhs.handle_;
316 
317   rhs.device_ = tmp_device;
318   rhs.handle_ = tmp_handle;
319 }
320 
operator ==(const ComputePipeline::Descriptor & _1,const ComputePipeline::Descriptor & _2)321 bool operator==(
322     const ComputePipeline::Descriptor& _1,
323     const ComputePipeline::Descriptor& _2) {
324   return (
325       _1.pipeline_layout == _2.pipeline_layout &&
326       _1.shader_module == _2.shader_module &&
327       _1.specialization_constants == _2.specialization_constants);
328 }
329 
330 //
331 // PipelineLayoutCache
332 //
333 
// Creates an empty, mutex-protected cache of pipeline layouts for `device`.
PipelineLayoutCache::PipelineLayoutCache(VkDevice device)
    : cache_mutex_{}, device_(device), cache_{} {}
336 
PipelineLayoutCache(PipelineLayoutCache && other)337 PipelineLayoutCache::PipelineLayoutCache(PipelineLayoutCache&& other) noexcept
338     : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
339   std::lock_guard<std::mutex> lock(other.cache_mutex_);
340 }
341 
// Destroys all cached pipeline layouts before the cache itself goes away.
PipelineLayoutCache::~PipelineLayoutCache() {
  purge();
}
345 
retrieve(const PipelineLayoutCache::Key & key)346 VkPipelineLayout PipelineLayoutCache::retrieve(
347     const PipelineLayoutCache::Key& key) {
348   std::lock_guard<std::mutex> lock(cache_mutex_);
349 
350   auto it = cache_.find(key);
351   if (cache_.cend() == it) {
352     it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first;
353   }
354 
355   return it->second.handle();
356 }
357 
// Destroys every cached pipeline layout (Value destructors release the
// Vulkan handles). Safe to call concurrently with retrieve().
void PipelineLayoutCache::purge() {
  std::lock_guard<std::mutex> lock(cache_mutex_);
  cache_.clear();
}
362 
363 //
364 // ComputePipelineCache
365 //
366 
ComputePipelineCache(VkDevice device,const std::string & cache_data_path)367 ComputePipelineCache::ComputePipelineCache(
368     VkDevice device,
369     const std::string& cache_data_path)
370     : cache_mutex_{},
371       device_(device),
372       pipeline_cache_{VK_NULL_HANDLE},
373       cache_{},
374       cache_data_path_(cache_data_path) {
375   VkPipelineCacheCreateInfo pipeline_cache_create_info{};
376 
377   auto buffer = load_cache();
378 
379   pipeline_cache_create_info = {
380       VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType
381       nullptr, // pNext
382       0u, // flags
383       buffer.size(), // initialDataSize
384       buffer.data(), // pInitialData
385   };
386 
387   VK_CHECK(vkCreatePipelineCache(
388       device, &pipeline_cache_create_info, nullptr, &pipeline_cache_));
389 }
390 
ComputePipelineCache(ComputePipelineCache && other)391 ComputePipelineCache::ComputePipelineCache(
392     ComputePipelineCache&& other) noexcept
393     : cache_mutex_{},
394       device_(other.device_),
395       pipeline_cache_(other.pipeline_cache_),
396       cache_(std::move(other.cache_)) {
397   std::lock_guard<std::mutex> lock(other.cache_mutex_);
398 
399   other.pipeline_cache_ = VK_NULL_HANDLE;
400 }
401 
~ComputePipelineCache()402 ComputePipelineCache::~ComputePipelineCache() {
403   purge();
404 
405   if (pipeline_cache_ == VK_NULL_HANDLE) {
406     return;
407   }
408 
409   vkDestroyPipelineCache(device_, pipeline_cache_, nullptr);
410   pipeline_cache_ = VK_NULL_HANDLE;
411 }
412 
retrieve(const ComputePipelineCache::Key & key)413 VkPipeline ComputePipelineCache::retrieve(
414     const ComputePipelineCache::Key& key) {
415   std::lock_guard<std::mutex> lock(cache_mutex_);
416 
417   auto it = cache_.find(key);
418   if (cache_.cend() == it) {
419     it = cache_
420              .insert(
421                  {key,
422                   ComputePipelineCache::Value(device_, key, pipeline_cache_)})
423              .first;
424   }
425 
426   return it->second.handle();
427 }
428 
purge()429 void ComputePipelineCache::purge() {
430   cache_.clear();
431 }
432 
load_cache()433 std::vector<char> ComputePipelineCache::load_cache() {
434   // No optimization if path is unspecified
435   if (cache_data_path_.empty()) {
436     return {};
437   }
438 
439   // Return if file doesn't exist; this is expected on first model-load
440   std::ifstream file(cache_data_path_, std::ios::binary | std::ios::ate);
441   if (file.fail()) {
442     return {};
443   }
444 
445   auto size = file.tellg();
446   file.seekg(0, std::ios::beg);
447 
448   std::vector<char> buffer(size);
449   file.read(buffer.data(), size);
450 
451   return buffer;
452 }
453 
save_cache()454 void ComputePipelineCache::save_cache() {
455   // No optimization if path is unspecified
456   if (cache_data_path_.empty()) {
457     return;
458   }
459 
460   // Return if file exists; the cache is already saved
461   std::ifstream ifile(cache_data_path_);
462   if (ifile.good()) {
463     return;
464   }
465 
466   size_t size{};
467   vkGetPipelineCacheData(device_, pipeline_cache_, &size, nullptr);
468 
469   std::vector<char> buffer(size);
470   vkGetPipelineCacheData(device_, pipeline_cache_, &size, buffer.data());
471 
472   std::ofstream file(cache_data_path_, std::ios::binary);
473   file.write(buffer.data(), buffer.size());
474 }
475 
476 } // namespace vkapi
477 } // namespace vkcompute
478