/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/vk_api/Pipeline.h>

#include <fstream>

namespace vkcompute {
namespace vkapi {

//
// Utility Functions
//

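// Translate a (pipeline stage, memory access) combination into the
// VkAccessFlags bitmask expected by Vulkan memory barriers. Only the stages
// used by this backend (compute, host, and transfer) are handled.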
VkAccessFlags vk_access(
    const PipelineStageFlags stage,
    const MemoryAccessFlags access) {
  VkAccessFlags vk_access = 0u;

  if (access & MemoryAccessType::READ) {
    if (stage & PipelineStage::COMPUTE) {
      vk_access |= VK_ACCESS_SHADER_READ_BIT;
    }

    if (stage & PipelineStage::HOST) {
      vk_access |= VK_ACCESS_HOST_READ_BIT;
    }

    if (stage & PipelineStage::TRANSFER) {
      vk_access |= VK_ACCESS_TRANSFER_READ_BIT;
    }
  }

  if (access & MemoryAccessType::WRITE) {
    if (stage & PipelineStage::COMPUTE) {
      vk_access |= VK_ACCESS_SHADER_WRITE_BIT;
    }

    if (stage & PipelineStage::HOST) {
      vk_access |= VK_ACCESS_HOST_WRITE_BIT;
    }

    if (stage & PipelineStage::TRANSFER) {
      vk_access |= VK_ACCESS_TRANSFER_WRITE_BIT;
    }
  }

  return vk_access;
}

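// Translate a pipeline stage bitmask into the corresponding
// VkPipelineStageFlags bitmask.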
VkPipelineStageFlags vk_stage(const PipelineStageFlags stage) {
  VkPipelineStageFlags vk_stage = 0u;

  if (stage & PipelineStage::COMPUTE) {
    vk_stage |= VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;
  }

  if (stage & PipelineStage::HOST) {
    vk_stage |= VK_PIPELINE_STAGE_HOST_BIT;
  }

  if (stage & PipelineStage::TRANSFER) {
    vk_stage |= VK_PIPELINE_STAGE_TRANSFER_BIT;
  }

  return vk_stage;
}

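// Select an image layout appropriate for how an image is used at a given
// pipeline stage: read-only compute access maps to SHADER_READ_ONLY_OPTIMAL,
// transfer operations use the dedicated SRC/DST layouts, and any other
// compute access falls back to GENERAL.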
VkImageLayout vk_layout(
    const PipelineStageFlags stage,
    const MemoryAccessFlags access) {
  switch (stage) {
    case PipelineStage::COMPUTE:
      switch (access) {
        case MemoryAccessType::READ:
          return VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL;
        default:
          return VK_IMAGE_LAYOUT_GENERAL;
      }
      break;
    case PipelineStage::TRANSFER:
      switch (access) {
        case MemoryAccessType::READ:
          return VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL;
        case MemoryAccessType::WRITE:
          return VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL;
        default:
          VK_THROW("Invalid memory access type for transfer stage!");
      }
      break;
    default:
      VK_THROW("Cannot determine appropriate image layout");
  }

  return VK_IMAGE_LAYOUT_UNDEFINED;
}

//
// SpecVar
//

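// SpecVar is a small tagged union holding a single specialization constant;
// the active member of `value` is determined by which constructor was used.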
SpecVar::SpecVar() : type(SpecVar::Type::INT) {
  value.as_int32 = 0;
}

SpecVar::SpecVar(const float val) : type(SpecVar::Type::FLOAT) {
  value.as_float = val;
}

SpecVar::SpecVar(const int32_t val) : type(SpecVar::Type::INT) {
  value.as_int32 = val;
}

SpecVar::SpecVar(const uint32_t val) : type(SpecVar::Type::UINT) {
  value.as_uint32 = val;
}

SpecVar::SpecVar(const bool val) : type(SpecVar::Type::BOOL) {
  value.as_bool = val;
}

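// val_size() and val_offset() describe the active value for a
// VkSpecializationMapEntry: the size of the value in bytes, and the byte
// offset of the `value` member within a SpecVar.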
uint32_t SpecVar::val_size() const {
  switch (type) {
    case SpecVar::Type::FLOAT:
      return sizeof(float);
    case SpecVar::Type::INT:
      return sizeof(int32_t);
    case SpecVar::Type::UINT:
      return sizeof(uint32_t);
    case SpecVar::Type::BOOL:
      return sizeof(bool);
  }
  return 4;
}

uint32_t SpecVar::val_offset() const {
  return utils::safe_downcast<uint32_t>(offsetof(SpecVar, value));
}

bool operator==(const SpecVar& lhs, const SpecVar& rhs) {
  if (lhs.type != rhs.type) {
    return false;
  }
  switch (lhs.type) {
    case SpecVar::Type::FLOAT:
      return lhs.value.as_float == rhs.value.as_float;
    case SpecVar::Type::INT:
      return lhs.value.as_int32 == rhs.value.as_int32;
    case SpecVar::Type::UINT:
      return lhs.value.as_uint32 == rhs.value.as_uint32;
    case SpecVar::Type::BOOL:
      return lhs.value.as_bool == rhs.value.as_bool;
  }
  return false;
}

bool operator!=(const SpecVar& lhs, const SpecVar& rhs) {
  return !(lhs == rhs);
}

SpecVarList::SpecVarList() {}

SpecVarList::SpecVarList(std::initializer_list<SpecVar> init_list) {
  vars.resize(init_list.size());
  std::copy(init_list.begin(), init_list.end(), vars.begin());
}

void SpecVarList::append(const SpecVarList& other) {
  vars.insert(vars.end(), other.vars.begin(), other.vars.end());
}

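// Build the VkSpecializationMapEntry array that describes how the packed
// SpecVar storage maps onto specialization constants: constant IDs are
// assigned by list index, and each entry's offset advances by a full
// sizeof(SpecVar) stride so that entries line up with the SpecVar elements.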
std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
    const {
  std::vector<VkSpecializationMapEntry> map_entries;
  map_entries.resize(vars.size());
  uint32_t cur_offset = 0u;
  for (uint32_t i = 0; i < vars.size(); ++i) {
    map_entries.at(i) = {
        i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()};
    cur_offset += sizeof(SpecVar);
  }
  return map_entries;
}

bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) {
  if (lhs.size() != rhs.size()) {
    return false;
  }
  for (uint32_t i = 0; i < lhs.size(); ++i) {
    if (lhs.vars.at(i) != rhs.vars.at(i)) {
      return false;
    }
  }
  return true;
}

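// Illustrative usage (a sketch, not code from this file): a compute shader
// whose local workgroup size is exposed through specialization constants
// 0..2 could be configured as
//
//   SpecVarList spec_vars = {64u, 1u, 1u};
//
// generate_map_entries() would then yield three entries with constantIDs 0,
// 1, and 2, each sizeof(uint32_t) bytes wide.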
//
// PipelineLayout
//

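// RAII wrapper that owns a VkPipelineLayout for the lifetime of the object.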
PipelineLayout::PipelineLayout(
    VkDevice device,
    VkDescriptorSetLayout descriptor_layout)
    : device_(device), handle_{VK_NULL_HANDLE} {
  // TODO: Enable push constants
  const VkPipelineLayoutCreateInfo pipeline_layout_create_info{
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      1u, // setLayoutCount
      &descriptor_layout, // pSetLayouts
      0u, // pushConstantRangeCount
      nullptr, // pPushConstantRanges
  };

  VK_CHECK(vkCreatePipelineLayout(
      device_, &pipeline_layout_create_info, nullptr, &handle_));
}

PipelineLayout::PipelineLayout(PipelineLayout&& other) noexcept
    : device_(other.device_), handle_(other.handle_) {
  other.handle_ = VK_NULL_HANDLE;
}

PipelineLayout::~PipelineLayout() {
  if (handle_ == VK_NULL_HANDLE) {
    return;
  }
  vkDestroyPipelineLayout(device_, handle_, nullptr);
  handle_ = VK_NULL_HANDLE;
}

void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
  VkDevice tmp_device = lhs.device_;
  VkPipelineLayout tmp_handle = lhs.handle_;

  lhs.device_ = rhs.device_;
  lhs.handle_ = rhs.handle_;

  rhs.device_ = tmp_device;
  rhs.handle_ = tmp_handle;
}

//
// ComputePipeline
//

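// RAII wrapper around a VkPipeline built for a single compute shader entry
// point ("main"), with specialization constants baked in at creation time.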
ComputePipeline::ComputePipeline(
    VkDevice device,
    const ComputePipeline::Descriptor& descriptor,
    VkPipelineCache pipeline_cache)
    : device_(device), handle_{VK_NULL_HANDLE} {
  std::vector<VkSpecializationMapEntry> map_entries =
      descriptor.specialization_constants.generate_map_entries();

  const VkSpecializationInfo specialization_info{
      descriptor.specialization_constants.size(), // mapEntryCount
      map_entries.data(), // pMapEntries
      descriptor.specialization_constants.data_nbytes(), // dataSize
      descriptor.specialization_constants.data(), // pData
  };

  const VkPipelineShaderStageCreateInfo shader_stage_create_info{
      VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      VK_SHADER_STAGE_COMPUTE_BIT, // stage
      descriptor.shader_module, // module
      "main", // pName
      &specialization_info, // pSpecializationInfo
  };

  const VkComputePipelineCreateInfo compute_pipeline_create_info{
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      shader_stage_create_info, // stage
      descriptor.pipeline_layout, // layout
      VK_NULL_HANDLE, // basePipelineHandle
      0u, // basePipelineIndex
  };

  VK_CHECK(vkCreateComputePipelines(
      device_,
      pipeline_cache,
      1u,
      &compute_pipeline_create_info,
      nullptr,
      &handle_));
}

ComputePipeline::ComputePipeline(ComputePipeline&& other) noexcept
    : device_(other.device_), handle_(other.handle_) {
  other.handle_ = VK_NULL_HANDLE;
}

ComputePipeline::~ComputePipeline() {
  if (handle_ == VK_NULL_HANDLE) {
    return;
  }
  vkDestroyPipeline(device_, handle_, nullptr);
  handle_ = VK_NULL_HANDLE;
}

void swap(ComputePipeline& lhs, ComputePipeline& rhs) noexcept {
  VkDevice tmp_device = lhs.device_;
  VkPipeline tmp_handle = lhs.handle_;

  lhs.device_ = rhs.device_;
  lhs.handle_ = rhs.handle_;

  rhs.device_ = tmp_device;
  rhs.handle_ = tmp_handle;
}

bool operator==(
    const ComputePipeline::Descriptor& _1,
    const ComputePipeline::Descriptor& _2) {
  return (
      _1.pipeline_layout == _2.pipeline_layout &&
      _1.shader_module == _2.shader_module &&
      _1.specialization_constants == _2.specialization_constants);
}

//
// PipelineLayoutCache
//

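// Thread-safe cache that creates each unique pipeline layout at most once
// and hands out the cached handle on subsequent retrievals.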
PipelineLayoutCache::PipelineLayoutCache(VkDevice device)
    : cache_mutex_{}, device_(device), cache_{} {}

PipelineLayoutCache::PipelineLayoutCache(PipelineLayoutCache&& other) noexcept
    : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) {
  std::lock_guard<std::mutex> lock(other.cache_mutex_);
}

PipelineLayoutCache::~PipelineLayoutCache() {
  purge();
}

VkPipelineLayout PipelineLayoutCache::retrieve(
    const PipelineLayoutCache::Key& key) {
  std::lock_guard<std::mutex> lock(cache_mutex_);

  auto it = cache_.find(key);
  if (cache_.cend() == it) {
    it = cache_.insert({key, PipelineLayoutCache::Value(device_, key)}).first;
  }

  return it->second.handle();
}

void PipelineLayoutCache::purge() {
  std::lock_guard<std::mutex> lock(cache_mutex_);
  cache_.clear();
}

//
// ComputePipelineCache
//

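// Thread-safe cache of compute pipelines. Creation goes through a
// VkPipelineCache whose contents may be loaded from, and explicitly saved
// back to, the file at cache_data_path_ to speed up subsequent model loads.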
ComputePipelineCache::ComputePipelineCache(
    VkDevice device,
    const std::string& cache_data_path)
    : cache_mutex_{},
      device_(device),
      pipeline_cache_{VK_NULL_HANDLE},
      cache_{},
      cache_data_path_(cache_data_path) {
  VkPipelineCacheCreateInfo pipeline_cache_create_info{};

  auto buffer = load_cache();

  pipeline_cache_create_info = {
      VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType
      nullptr, // pNext
      0u, // flags
      buffer.size(), // initialDataSize
      buffer.data(), // pInitialData
  };

  VK_CHECK(vkCreatePipelineCache(
      device, &pipeline_cache_create_info, nullptr, &pipeline_cache_));
}

ComputePipelineCache::ComputePipelineCache(
    ComputePipelineCache&& other) noexcept
    : cache_mutex_{},
      device_(other.device_),
      pipeline_cache_(other.pipeline_cache_),
      cache_(std::move(other.cache_)) {
  std::lock_guard<std::mutex> lock(other.cache_mutex_);

  other.pipeline_cache_ = VK_NULL_HANDLE;
}

ComputePipelineCache::~ComputePipelineCache() {
  purge();

  if (pipeline_cache_ == VK_NULL_HANDLE) {
    return;
  }

  vkDestroyPipelineCache(device_, pipeline_cache_, nullptr);
  pipeline_cache_ = VK_NULL_HANDLE;
}

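// Return the pipeline for the given descriptor, creating and caching it on
// first use.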
VkPipeline ComputePipelineCache::retrieve(
    const ComputePipelineCache::Key& key) {
  std::lock_guard<std::mutex> lock(cache_mutex_);

  auto it = cache_.find(key);
  if (cache_.cend() == it) {
    it = cache_
             .insert(
                 {key,
                  ComputePipelineCache::Value(device_, key, pipeline_cache_)})
             .first;
  }

  return it->second.handle();
}

void ComputePipelineCache::purge() {
  // Lock the cache mutex so that purging cannot race a concurrent retrieve(),
  // mirroring PipelineLayoutCache::purge()
  std::lock_guard<std::mutex> lock(cache_mutex_);
  cache_.clear();
}

std::vector<char> ComputePipelineCache::load_cache() {
  // No optimization if path is unspecified
  if (cache_data_path_.empty()) {
    return {};
  }

  // Return if file doesn't exist; this is expected on first model-load
  std::ifstream file(cache_data_path_, std::ios::binary | std::ios::ate);
  if (file.fail()) {
    return {};
  }

  auto size = file.tellg();
  file.seekg(0, std::ios::beg);

  std::vector<char> buffer(size);
  file.read(buffer.data(), size);

  return buffer;
}

void ComputePipelineCache::save_cache() {
  // No optimization if path is unspecified
  if (cache_data_path_.empty()) {
    return;
  }

  // Return if file exists; the cache is already saved
  std::ifstream ifile(cache_data_path_);
  if (ifile.good()) {
    return;
  }

  // Query the required size first, then fetch the cache data into the buffer
  size_t size{};
  vkGetPipelineCacheData(device_, pipeline_cache_, &size, nullptr);

  std::vector<char> buffer(size);
  vkGetPipelineCacheData(device_, pipeline_cache_, &size, buffer.data());

  std::ofstream file(cache_data_path_, std::ios::binary);
  file.write(buffer.data(), buffer.size());
}

} // namespace vkapi
} // namespace vkcompute