xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/api/Context.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/api/Context.h>
10 
// Capacity used for the descriptor pool of the global context: it is applied
// both as the maximum number of sets and as the count for every descriptor
// type. Override at build time by predefining the macro.
#if !defined(VULKAN_DESCRIPTOR_POOL_SIZE)
#define VULKAN_DESCRIPTOR_POOL_SIZE 1024u
#endif

// Maximum number of queries held by the global context's query pool.
// Override at build time by predefining the macro.
#if !defined(VULKAN_QUERY_POOL_SIZE)
#define VULKAN_QUERY_POOL_SIZE 4096u
#endif
18 
19 namespace vkcompute {
20 namespace api {
21 
Context(size_t adapter_i,const ContextConfig & config)22 Context::Context(size_t adapter_i, const ContextConfig& config)
23     : config_(config),
24       // Important handles
25       adapter_p_(vkapi::runtime()->get_adapter_p(adapter_i)),
26       device_(adapter_p_->device_handle()),
27       queue_(adapter_p_->request_queue()),
28       // Resource pools
29       command_pool_(device_, queue_.family_index, config_.cmd_pool_config),
30       descriptor_pool_(device_, config_.descriptor_pool_config),
31       fences_(device_),
32       // Profiling
33       querypool_(config_.query_pool_config, nullptr),
34       // Command buffer submission
35       cmd_mutex_{},
36       cmd_(VK_NULL_HANDLE, 0u),
37       submit_count_{0u},
38       // Memory Management
39       buffer_clearlist_mutex_{},
40       buffers_to_clear_{},
41       image_clearlist_mutex_{},
42       images_to_clear_{},
43       preferred_image_tiling_{VK_IMAGE_TILING_OPTIMAL} {
44   if (adapter_p_->linear_tiling_3d_enabled()) {
45     preferred_image_tiling_ = VK_IMAGE_TILING_LINEAR;
46   }
47 }
48 
~Context()49 Context::~Context() {
50   try {
51     flush();
52     // Let the device know the context is done with the queue
53     adapter_p_->return_queue(queue_);
54   } catch (...) {
55   }
56 }
57 
initialize_querypool()58 void Context::initialize_querypool() {
59   querypool_.initialize(adapter_p_);
60 }
61 
cmd_reset_querypool()62 void Context::cmd_reset_querypool() {
63   if (querypool_) {
64     set_cmd();
65     querypool_.reset_querypool(cmd_);
66   }
67 }
68 
report_shader_dispatch_start(const std::string & shader_name,const utils::uvec3 & global_wg_size,const utils::uvec3 & local_wg_size,const uint32_t dispatch_id)69 void Context::report_shader_dispatch_start(
70     const std::string& shader_name,
71     const utils::uvec3& global_wg_size,
72     const utils::uvec3& local_wg_size,
73     const uint32_t dispatch_id) {
74   if (querypool_) {
75     querypool_.shader_profile_begin(
76         cmd_,
77         dispatch_id,
78         shader_name,
79         vkapi::create_extent3d(global_wg_size),
80         vkapi::create_extent3d(local_wg_size));
81   }
82 }
83 
report_shader_dispatch_end()84 void Context::report_shader_dispatch_end() {
85   if (querypool_) {
86     querypool_.shader_profile_end(cmd_);
87   }
88 }
89 
/*
 * Binds the compute pipeline for `shader_descriptor` on the current command
 * buffer and returns a descriptor set matching the shader's layout.
 *
 * @param shader_descriptor     shader whose layout/module select the pipeline
 * @param local_workgroup_size  written as the first three specialization
 *                              constants and passed to bind_pipeline()
 * @param additional_constants  specialization constants appended after the
 *                              three workgroup-size entries
 * @return a descriptor set allocated from the context's descriptor pool
 */
vkapi::DescriptorSet Context::get_descriptor_set(
    const vkapi::ShaderInfo& shader_descriptor,
    const utils::uvec3& local_workgroup_size,
    const vkapi::SpecVarList& additional_constants) {
  VkDescriptorSetLayout shader_layout =
      shader_layout_cache().retrieve(shader_descriptor.kernel_layout);

  VkPipelineLayout pipeline_layout =
      pipeline_layout_cache().retrieve(shader_layout);

  // The local workgroup size always occupies the leading spec constant slots.
  vkapi::SpecVarList spec_constants = {
      SV(local_workgroup_size[0u]),
      SV(local_workgroup_size[1u]),
      SV(local_workgroup_size[2u])};

  spec_constants.append(additional_constants);

  // Reuse the layout fetched above rather than querying the cache a second
  // time; the pipeline key and the layout bound below are the same handle.
  VkPipeline pipeline = pipeline_cache().retrieve(
      {pipeline_layout,
       shader_cache().retrieve(shader_descriptor),
       spec_constants});

  cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);

  return descriptor_pool().get_descriptor_set(
      shader_layout, shader_descriptor.kernel_layout);
}
117 
register_shader_dispatch(const vkapi::DescriptorSet & descriptors,vkapi::PipelineBarrier & pipeline_barrier,const vkapi::ShaderInfo & shader_descriptor,const utils::uvec3 & global_workgroup_size)118 void Context::register_shader_dispatch(
119     const vkapi::DescriptorSet& descriptors,
120     vkapi::PipelineBarrier& pipeline_barrier,
121     const vkapi::ShaderInfo& shader_descriptor,
122     const utils::uvec3& global_workgroup_size) {
123   // Adjust the global workgroup size based on the output tile size
124   uint32_t global_wg_w = utils::div_up(
125       global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]);
126   uint32_t global_wg_h = utils::div_up(
127       global_workgroup_size[1u], shader_descriptor.out_tile_size[1u]);
128   uint32_t global_wg_d = utils::div_up(
129       global_workgroup_size[2u], shader_descriptor.out_tile_size[2u]);
130 
131   // Submitting a global work group size of 0 is undefined behaviour. If this is
132   // detected then submit a single workgroup instead.
133   if (global_wg_w == 0u || global_wg_h == 0u || global_wg_d == 0u) {
134     global_wg_w = 1u;
135     global_wg_h = 1u;
136     global_wg_d = 1u;
137   }
138 
139   const utils::uvec3 effective_global_wg = {
140       global_wg_w,
141       global_wg_h,
142       global_wg_d,
143   };
144 
145   cmd_.bind_descriptors(descriptors.get_bind_handle());
146   cmd_.insert_barrier(pipeline_barrier);
147 
148   cmd_.dispatch(effective_global_wg);
149 }
150 
register_blit(vkapi::PipelineBarrier & pipeline_barrier,vkapi::VulkanImage & src,vkapi::VulkanImage & dst)151 void Context::register_blit(
152     vkapi::PipelineBarrier& pipeline_barrier,
153     vkapi::VulkanImage& src,
154     vkapi::VulkanImage& dst) {
155   cmd_.insert_barrier(pipeline_barrier);
156   cmd_.blit(src, dst);
157 }
158 
submit_cmd_to_gpu(VkFence fence_handle,const bool final_use)159 void Context::submit_cmd_to_gpu(VkFence fence_handle, const bool final_use) {
160   if (cmd_) {
161     cmd_.end();
162     adapter_p_->submit_cmd(
163         queue_, cmd_.get_submit_handle(final_use), fence_handle);
164 
165     submit_count_ = 0u;
166   }
167 }
168 
flush()169 void Context::flush() {
170   VK_CHECK(vkQueueWaitIdle(queue()));
171 
172   command_pool_.flush();
173   descriptor_pool_.flush();
174 
175   // If there is an existing command buffer, invalidate it
176   if (cmd_) {
177     cmd_.invalidate();
178   }
179 
180   std::lock_guard<std::mutex> bufferlist_lock(buffer_clearlist_mutex_);
181   std::lock_guard<std::mutex> imagelist_lock(image_clearlist_mutex_);
182   buffers_to_clear_.clear();
183   images_to_clear_.clear();
184 }
185 
available()186 bool available() {
187   return context();
188 }
189 
context()190 Context* context() {
191   static const std::unique_ptr<Context> context([]() -> Context* {
192     try {
193       const uint32_t cmd_submit_frequency = 16u;
194 
195       const vkapi::CommandPoolConfig cmd_config{
196           32u, // cmdPoolInitialSize
197           8u, // cmdPoolBatchSize
198       };
199 
200       const vkapi::DescriptorPoolConfig descriptor_pool_config{
201           VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorPoolMaxSets
202           VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorUniformBufferCount
203           VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorStorageBufferCount
204           VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorCombinedSamplerCount
205           VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorStorageImageCount
206           32u, // descriptorPileSizes
207       };
208 
209       const vkapi::QueryPoolConfig query_pool_config{
210           VULKAN_QUERY_POOL_SIZE, // maxQueryCount
211           256u, // initialReserveSize
212       };
213 
214       const ContextConfig config{
215           cmd_submit_frequency,
216           cmd_config,
217           descriptor_pool_config,
218           query_pool_config,
219       };
220 
221       return new Context(vkapi::runtime()->default_adapter_i(), config);
222     } catch (...) {
223     }
224 
225     return nullptr;
226   }());
227 
228   return context.get();
229 }
230 
231 } // namespace api
232 } // namespace vkcompute
233