1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName 12 13 #include <executorch/backends/vulkan/runtime/api/Context.h> 14 15 #include <executorch/backends/vulkan/runtime/vk_api/memory/Buffer.h> 16 17 #include <cstring> 18 19 namespace vkcompute { 20 namespace api { 21 22 class StagingBuffer final { 23 private: 24 Context* context_p_; 25 vkapi::ScalarType dtype_; 26 size_t numel_; 27 size_t nbytes_; 28 vkapi::VulkanBuffer vulkan_buffer_; 29 30 void* mapped_data_; 31 32 public: StagingBuffer(Context * context_p,const vkapi::ScalarType dtype,const size_t numel)33 StagingBuffer( 34 Context* context_p, 35 const vkapi::ScalarType dtype, 36 const size_t numel) 37 : context_p_(context_p), 38 dtype_(dtype), 39 numel_(numel), 40 nbytes_(element_size(dtype_) * numel_), 41 vulkan_buffer_( 42 context_p_->adapter_ptr()->vma().create_staging_buffer(nbytes_)), 43 mapped_data_(nullptr) {} 44 45 StagingBuffer(const StagingBuffer&) = delete; 46 StagingBuffer& operator=(const StagingBuffer&) = delete; 47 48 StagingBuffer(StagingBuffer&&) = default; 49 StagingBuffer& operator=(StagingBuffer&&) = default; 50 ~StagingBuffer()51 ~StagingBuffer() { 52 context_p_->register_buffer_cleanup(vulkan_buffer_); 53 } 54 dtype()55 inline vkapi::ScalarType dtype() { 56 return dtype_; 57 } 58 buffer()59 inline vkapi::VulkanBuffer& buffer() { 60 return vulkan_buffer_; 61 } 62 data()63 inline void* data() { 64 if (!mapped_data_) { 65 mapped_data_ = vulkan_buffer_.allocation_info().pMappedData; 66 } 67 return mapped_data_; 68 } 69 numel()70 inline size_t numel() { 71 return numel_; 72 } 73 nbytes()74 inline size_t nbytes() { 75 return nbytes_; 76 } 77 copy_from(const void * src,const size_t nbytes)78 inline void copy_from(const void* src, const size_t nbytes) { 79 VK_CHECK_COND(nbytes <= nbytes_); 80 memcpy(data(), src, nbytes); 81 vmaFlushAllocation( 82 vulkan_buffer_.vma_allocator(), 83 vulkan_buffer_.allocation(), 84 0u, 85 VK_WHOLE_SIZE); 86 } 87 copy_to(void * dst,const size_t nbytes)88 inline void copy_to(void* dst, const size_t nbytes) { 89 VK_CHECK_COND(nbytes <= nbytes_); 90 vmaInvalidateAllocation( 91 vulkan_buffer_.vma_allocator(), 92 vulkan_buffer_.allocation(), 93 0u, 94 VK_WHOLE_SIZE); 95 memcpy(dst, data(), nbytes); 96 } 97 set_staging_zeros()98 inline void set_staging_zeros() { 99 memset(data(), 0, nbytes_); 100 } 101 }; 102 103 } // namespace api 104 } // namespace vkcompute 105