/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#include <executorch/backends/vulkan/runtime/utils/MacroUtils.h>

#include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>
#include <executorch/backends/vulkan/runtime/vk_api/Command.h>
#include <executorch/backends/vulkan/runtime/vk_api/Descriptor.h>
#include <executorch/backends/vulkan/runtime/vk_api/Fence.h>
#include <executorch/backends/vulkan/runtime/vk_api/QueryPool.h>
#include <executorch/backends/vulkan/runtime/vk_api/Runtime.h>
#include <executorch/backends/vulkan/runtime/vk_api/VkUtils.h>

namespace vkcompute {
namespace api {

struct ContextConfig final {
  uint32_t cmd_submit_frequency;
  vkapi::CommandPoolConfig cmd_pool_config;
  vkapi::DescriptorPoolConfig descriptor_pool_config;
  vkapi::QueryPoolConfig query_pool_config;
};
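
/*
 * Example (illustrative sketch): ContextConfig is an aggregate, so a config
 * can be brace-initialized. The nested pool configs are value-initialized
 * here for brevity; real call sites will want to populate them, and the
 * frequency value shown is arbitrary:
 *
 *   const ContextConfig config{
 *       8u, // cmd_submit_frequency: submit after every 8 recorded dispatches
 *       vkapi::CommandPoolConfig{},
 *       vkapi::DescriptorPoolConfig{},
 *       vkapi::QueryPoolConfig{},
 *   };
 *   Context context(0u, config); // adapter_i = 0: use the first adapter
 */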

//
// Vulkan Context holds onto all relevant Vulkan state as it pertains to our
// use of Vulkan in PyTorch. A Context is associated with one, and only one,
// Adapter as a precursor to multi-GPU support. All Vulkan tensors in PyTorch
// are associated with a Context to make tensor <-> device affinity explicit.
// The context is currently a global object, but technically it does not need
// to be if we were to make it explicit to the user.
//

class Context final {
 public:
  explicit Context(size_t adapter_i, const ContextConfig&);

  Context(const Context&) = delete;
  Context& operator=(const Context&) = delete;

  Context(Context&&) = delete;
  Context& operator=(Context&&) = delete;

  ~Context();

 private:
  // Config
  ContextConfig config_;
  // Important handles
  vkapi::Adapter* adapter_p_;
  VkDevice device_;
  vkapi::Adapter::Queue queue_;
  // Resource Pools
  vkapi::CommandPool command_pool_;
  vkapi::DescriptorPool descriptor_pool_;
  vkapi::FencePool fences_;
  // Diagnostics
  vkapi::QueryPool querypool_;
  // Command buffer submission
  std::mutex cmd_mutex_;
  vkapi::CommandBuffer cmd_;
  uint32_t submit_count_;
  // Memory Management
  std::mutex buffer_clearlist_mutex_;
  std::vector<vkapi::VulkanBuffer> buffers_to_clear_;
  std::mutex image_clearlist_mutex_;
  std::vector<vkapi::VulkanImage> images_to_clear_;
  // Misc
  VkImageTiling preferred_image_tiling_;

 public:
  // Adapter access

  inline vkapi::Adapter* adapter_ptr() {
    return adapter_p_;
  }

  inline VkDevice device() {
    return device_;
  }

  inline VkQueue queue() {
    return queue_.handle;
  }

  // Device Caches

  inline vkapi::ShaderLayoutCache& shader_layout_cache() {
    return adapter_ptr()->shader_layout_cache();
  }

  inline vkapi::ShaderCache& shader_cache() {
    return adapter_ptr()->shader_cache();
  }

  inline vkapi::PipelineLayoutCache& pipeline_layout_cache() {
    return adapter_ptr()->pipeline_layout_cache();
  }

  inline vkapi::ComputePipelineCache& pipeline_cache() {
    return adapter_ptr()->compute_pipeline_cache();
  }

  // Resource Pools

  inline vkapi::DescriptorPool& descriptor_pool() {
    return descriptor_pool_;
  }

  inline vkapi::FencePool& fences() {
    return fences_;
  }

  // Diagnostics

  inline vkapi::QueryPool& querypool() {
    return querypool_;
  }

  inline VkImageTiling preferred_image_tiling() {
    return preferred_image_tiling_;
  }

  /*
   * By default, the querypool attached to a Context instance is uninitialized.
   * This function triggers the querypool to be created via vkCreateQueryPool.
   */
  void initialize_querypool();

  /*
   * Encodes a vkResetQueryPool command to the current command buffer, and
   * resets the internal state of the querypool. If the querypool is not
   * initialized this function is a no-op.
   */
  void cmd_reset_querypool();

  /*
   * Encodes a vkCmdWriteTimestamp command to the current command buffer and
   * records some metadata about the shader that will be dispatched. If the
   * querypool is not initialized this function is a no-op.
   */
  void report_shader_dispatch_start(
      const std::string& shader_name,
      const utils::uvec3& global_wg_size,
      const utils::uvec3& local_wg_size,
      const uint32_t dispatch_id = UINT32_MAX);

  /*
   * Encodes a vkCmdWriteTimestamp command to the current command buffer to
   * record when the last shader that was dispatched has completed execution.
   * If the querypool is not initialized this function is a no-op.
   */
  void report_shader_dispatch_end();
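
  /*
   * Example (illustrative sketch): profiling is opt-in. Once the querypool
   * has been initialized, the report_shader_dispatch_* functions (which are
   * invoked internally by submit_compute_job) record timestamps around each
   * dispatch:
   *
   *   context()->initialize_querypool();
   *   // ... record and submit compute jobs ...
   *   context()->cmd_reset_querypool(); // start a fresh profiling window
   */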

  // Memory Management

  // Moves the buffer into a deferred-destruction list guarded by a mutex;
  // the list is emptied when the Context is flushed.
  void register_buffer_cleanup(vkapi::VulkanBuffer& buffer) {
    std::lock_guard<std::mutex> bufferlist_lock(buffer_clearlist_mutex_);
    buffers_to_clear_.emplace_back(std::move(buffer));
  }

  // Moves the image into a deferred-destruction list guarded by a mutex;
  // the list is emptied when the Context is flushed.
  void register_image_cleanup(vkapi::VulkanImage& image) {
    std::lock_guard<std::mutex> imagelist_lock(image_clearlist_mutex_);
    images_to_clear_.emplace_back(std::move(image));
  }
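
  /*
   * Example (illustrative sketch; `staging` is a hypothetical VulkanBuffer
   * that the host no longer needs):
   *
   *   context()->register_buffer_cleanup(staging);
   *   // `staging` has been moved from and must not be used afterwards.
   */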

  // GPU RPC

  inline std::unique_lock<std::mutex> dispatch_lock() {
    return std::unique_lock<std::mutex>(cmd_mutex_);
  }

  inline void set_cmd(bool reusable = false) {
    if (!cmd_) {
      cmd_ = command_pool_.get_new_cmd(reusable);
      cmd_.begin();
    }
  }

  vkapi::DescriptorSet get_descriptor_set(
      const vkapi::ShaderInfo&,
      const utils::uvec3&,
      const vkapi::SpecVarList&);

  inline vkapi::DescriptorSet get_descriptor_set(
      const vkapi::ShaderInfo& shader_descriptor,
      const utils::uvec3& local_work_group_size) {
    return get_descriptor_set(shader_descriptor, local_work_group_size, {});
  }
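
  /*
   * Example (illustrative sketch): the host-GPU sync pattern described in the
   * comments of submit_compute_job(). The caller locks the dispatch mutex,
   * submits with a fence, waits, flushes, and then releases the lock. It is
   * assumed here that the fence object exposes its VkFence through a
   * get_submit_handle() helper:
   *
   *   std::unique_lock<std::mutex> lock = context()->dispatch_lock();
   *   context()->submit_compute_job(..., fence.get_submit_handle(), ...);
   *   fence.wait();
   *   context()->flush();
   *   lock.unlock();
   */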

  void register_shader_dispatch(
      const vkapi::DescriptorSet&,
      vkapi::PipelineBarrier&,
      const vkapi::ShaderInfo&,
      const utils::uvec3&);

  void register_blit(
      vkapi::PipelineBarrier&,
      vkapi::VulkanImage& src,
      vkapi::VulkanImage& dst);

  template <typename... Arguments>
  bool submit_compute_job(
      const vkapi::ShaderInfo&,
      vkapi::PipelineBarrier&,
      const utils::uvec3&,
      const utils::uvec3&,
      const vkapi::SpecVarList&,
      VkFence fence_handle,
      const uint32_t dispatch_id,
      Arguments&&...);

  void submit_cmd_to_gpu(
      VkFence fence_handle = VK_NULL_HANDLE,
      const bool final_use = false);

  void flush();
};

bool available();

// The global Context is retrieved using this function; it is declared as a
// static local variable within the function body.
Context* context();
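
/*
 * Example (illustrative sketch): check that the global runtime is usable
 * before retrieving the Context:
 *
 *   if (api::available()) {
 *     api::Context* const ctx = api::context();
 *     VkDevice device = ctx->device();
 *   }
 */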

namespace detail {

inline void arg_is_empty(
    bool& any_is_empty,
    const vkapi::VulkanBuffer& buffer) {
  // bool(buffer) will evaluate to false if no memory has been allocated
  any_is_empty = any_is_empty || !buffer;
}

inline void arg_is_empty(bool& any_is_empty, const vkapi::VulkanImage& image) {
  // bool(image) will evaluate to false if no memory has been allocated
  any_is_empty = any_is_empty || !image;
}

inline void arg_is_empty(
    bool& any_is_empty,
    const vkapi::BufferBindInfo& bind_info) {
  any_is_empty = any_is_empty || (bind_info.handle == VK_NULL_HANDLE);
}

/*
  Reports if any VulkanBuffer or VulkanImage argument in a variadic argument
  list does not have any memory associated with it.
*/
template <typename... Arguments>
inline bool any_arg_is_empty(Arguments&&... arguments) {
  bool any_is_empty = false;
  // Expand the parameter pack into a dummy array initializer so that
  // arg_is_empty() is evaluated once per argument (a pre-C++17 substitute
  // for a fold expression).
  VK_UNUSED const int _[]{
      0,
      (arg_is_empty(any_is_empty, std::forward<Arguments>(arguments)), 0)...,
  };

  return any_is_empty;
}

template <size_t... Indices, typename... Arguments>
inline void bind(
    vkapi::DescriptorSet& descriptor_set,
    const std::index_sequence<Indices...>&,
    Arguments&&... arguments) {
  // Bind the i-th argument to the i-th binding slot of the descriptor set by
  // expanding the index sequence and the argument pack in lockstep.
  VK_UNUSED const int _[]{
      0,
      (descriptor_set.bind(Indices, std::forward<Arguments>(arguments)), 0)...,
  };
}

} // namespace detail

/*
  Records a compute shader dispatch into the current command buffer. If the
  number of submit_*_job calls reaches the configured submission frequency, or
  if a fence is provided, then the command buffer is submitted to the GPU for
  execution. Returns a bool indicating whether or not the function call
  resulted in a GPU queue submission.
*/
template <typename... Arguments>
inline bool Context::submit_compute_job(
    const vkapi::ShaderInfo& shader,
    vkapi::PipelineBarrier& pipeline_barrier,
    const utils::uvec3& global_work_group,
    const utils::uvec3& local_work_group_size,
    const vkapi::SpecVarList& specialization_constants,
    VkFence fence_handle,
    const uint32_t dispatch_id,
    Arguments&&... arguments) {
  // If any of the provided arguments does not have memory associated with it,
  // then exit early as there is no work to be done. However, if a fence has
  // been passed and the command buffer is not empty, then the current command
  // buffer must still be submitted so that the fence can be signaled.
  if (detail::any_arg_is_empty(arguments...)) {
    if (fence_handle != VK_NULL_HANDLE && submit_count_ > 0) {
      submit_cmd_to_gpu(fence_handle);
      return true;
    }
    return false;
  }

  // Serialize recording to the shared command buffer. The lock is not
  // associated with the mutex just yet, since in some cases it will be
  // externally managed.
  std::unique_lock<std::mutex> cmd_lock;
  // If a fence was passed, then assume that the host intends to sync with the
  // GPU, implying that there will be imminent calls to fence.wait() and
  // flush(). In that case, assume the mutex is externally managed: the
  // calling thread has already locked it prior to calling this function, and
  // will release it manually after calling flush(). This prevents more
  // dispatches from being recorded until the Context has been flushed.
  if (fence_handle == VK_NULL_HANDLE) {
    cmd_lock = std::unique_lock<std::mutex>(cmd_mutex_);
  }

  set_cmd();

  report_shader_dispatch_start(
      shader.kernel_name,
      global_work_group,
      local_work_group_size,
      dispatch_id);

  // Factor out template parameter independent code to minimize code bloat.
  vkapi::DescriptorSet descriptor_set = get_descriptor_set(
      shader, local_work_group_size, specialization_constants);

  detail::bind(
      descriptor_set,
      std::index_sequence_for<Arguments...>{},
      std::forward<Arguments>(arguments)...);

  // Factor out template parameter independent code to minimize code bloat.
  register_shader_dispatch(
      descriptor_set, pipeline_barrier, shader, global_work_group);

  report_shader_dispatch_end();

  submit_count_++;
  if (fence_handle != VK_NULL_HANDLE ||
      submit_count_ >= config_.cmd_submit_frequency) {
    submit_cmd_to_gpu(fence_handle);
    return true;
  }

  return false;
}
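
/*
 * Example (illustrative sketch; `shader`, `out_buffer`, and `in_buffer` are
 * hypothetical objects created elsewhere):
 *
 *   vkapi::PipelineBarrier barrier{};
 *   const bool submitted = api::context()->submit_compute_job(
 *       shader,          // vkapi::ShaderInfo for the compute kernel
 *       barrier,
 *       {64u, 64u, 1u},  // global work group size
 *       {8u, 8u, 1u},    // local work group size
 *       {},              // no specialization constants
 *       VK_NULL_HANDLE,  // no fence: let the Context batch submissions
 *       0u,              // dispatch_id, recorded by the querypool if enabled
 *       out_buffer,      // bound to descriptor set binding 0
 *       in_buffer);      // bound to descriptor set binding 1
 */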

} // namespace api
} // namespace vkcompute