xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/api/containers/StagingBuffer.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
12 
13 #include <executorch/backends/vulkan/runtime/api/Context.h>
14 
15 #include <executorch/backends/vulkan/runtime/vk_api/memory/Buffer.h>
16 
17 #include <cstring>
18 
19 namespace vkcompute {
20 namespace api {
21 
22 class StagingBuffer final {
23  private:
24   Context* context_p_;
25   vkapi::ScalarType dtype_;
26   size_t numel_;
27   size_t nbytes_;
28   vkapi::VulkanBuffer vulkan_buffer_;
29 
30   void* mapped_data_;
31 
32  public:
StagingBuffer(Context * context_p,const vkapi::ScalarType dtype,const size_t numel)33   StagingBuffer(
34       Context* context_p,
35       const vkapi::ScalarType dtype,
36       const size_t numel)
37       : context_p_(context_p),
38         dtype_(dtype),
39         numel_(numel),
40         nbytes_(element_size(dtype_) * numel_),
41         vulkan_buffer_(
42             context_p_->adapter_ptr()->vma().create_staging_buffer(nbytes_)),
43         mapped_data_(nullptr) {}
44 
45   StagingBuffer(const StagingBuffer&) = delete;
46   StagingBuffer& operator=(const StagingBuffer&) = delete;
47 
48   StagingBuffer(StagingBuffer&&) = default;
49   StagingBuffer& operator=(StagingBuffer&&) = default;
50 
~StagingBuffer()51   ~StagingBuffer() {
52     context_p_->register_buffer_cleanup(vulkan_buffer_);
53   }
54 
dtype()55   inline vkapi::ScalarType dtype() {
56     return dtype_;
57   }
58 
buffer()59   inline vkapi::VulkanBuffer& buffer() {
60     return vulkan_buffer_;
61   }
62 
data()63   inline void* data() {
64     if (!mapped_data_) {
65       mapped_data_ = vulkan_buffer_.allocation_info().pMappedData;
66     }
67     return mapped_data_;
68   }
69 
numel()70   inline size_t numel() {
71     return numel_;
72   }
73 
nbytes()74   inline size_t nbytes() {
75     return nbytes_;
76   }
77 
copy_from(const void * src,const size_t nbytes)78   inline void copy_from(const void* src, const size_t nbytes) {
79     VK_CHECK_COND(nbytes <= nbytes_);
80     memcpy(data(), src, nbytes);
81     vmaFlushAllocation(
82         vulkan_buffer_.vma_allocator(),
83         vulkan_buffer_.allocation(),
84         0u,
85         VK_WHOLE_SIZE);
86   }
87 
copy_to(void * dst,const size_t nbytes)88   inline void copy_to(void* dst, const size_t nbytes) {
89     VK_CHECK_COND(nbytes <= nbytes_);
90     vmaInvalidateAllocation(
91         vulkan_buffer_.vma_allocator(),
92         vulkan_buffer_.allocation(),
93         0u,
94         VK_WHOLE_SIZE);
95     memcpy(dst, data(), nbytes);
96   }
97 
set_staging_zeros()98   inline void set_staging_zeros() {
99     memset(data(), 0, nbytes_);
100   }
101 };
102 
103 } // namespace api
104 } // namespace vkcompute
105