1 #pragma once
2
3 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
4
5 #ifdef USE_VULKAN_API
6
7 #include <ATen/native/vulkan/api/vk_api.h>
8
9 #include <ATen/native/vulkan/api/Allocator.h>
10 #include <ATen/native/vulkan/api/Types.h>
11 #include <ATen/native/vulkan/api/Utils.h>
12
13 #include <mutex>
14 #include <ostream>
15 #include <stack>
16 #include <unordered_map>
17
// Debug-printing helper for the allocator statistics returned by
// MemoryAllocator::get_memory_statistics(). Deliberately declared at global
// scope, outside the at::native::vulkan::api namespace.
std::ostream& operator<<(std::ostream& out, VmaTotalStatistics stats);
19
20 namespace at {
21 namespace native {
22 namespace vulkan {
23 namespace api {
24
// Bitmask type used to combine MemoryAccessType flags (e.g. READ | WRITE).
using MemoryAccessFlags = uint8_t;

// Default VMA allocation strategy: prefer whichever allocation minimizes
// memory usage.
constexpr VmaAllocationCreateFlags DEFAULT_ALLOCATION_STRATEGY =
    VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT;

// Describes how mapped memory will be accessed. Values are distinct single
// bits so they can be OR-ed together into a MemoryAccessFlags mask.
enum MemoryAccessType : MemoryAccessFlags {
  NONE = 0u << 0u,
  READ = 1u << 0u,
  WRITE = 1u << 1u,
};
35
// Thin wrapper around a VkMemoryBarrier; the out-of-line constructor fills
// in the source/destination access masks.
struct MemoryBarrier final {
  VkMemoryBarrier handle;

  MemoryBarrier(
      const VkAccessFlags src_access_flags,
      const VkAccessFlags dst_access_flags);
};
43
/*
 * Wraps a VmaAllocation (a chunk of device memory managed by VMA) together
 * with the requirements and creation parameters it was made with. Move-only;
 * the out-of-line destructor presumably releases the allocation — confirm
 * against the .cpp.
 */
struct MemoryAllocation final {
  // Creates an empty (unallocated) instance; allocation == VK_NULL_HANDLE.
  explicit MemoryAllocation();

  explicit MemoryAllocation(
      const VmaAllocator,
      const VkMemoryRequirements&,
      const VmaAllocationCreateInfo&);

  MemoryAllocation(const MemoryAllocation&) = delete;
  MemoryAllocation& operator=(const MemoryAllocation&) = delete;

  MemoryAllocation(MemoryAllocation&&) noexcept;
  MemoryAllocation& operator=(MemoryAllocation&&) noexcept;

  ~MemoryAllocation();

  VkMemoryRequirements memory_requirements;
  // The properties this allocation was created with
  VmaAllocationCreateInfo create_info;
  // The allocator object this was allocated from
  VmaAllocator allocator;
  // Handles to the allocated memory
  VmaAllocation allocation;

  // True when this object holds a live allocation.
  operator bool() const {
    return (allocation != VK_NULL_HANDLE);
  }
};
72
/*
 * Wraps a VkBuffer together with the VMA allocation backing it. Move-only.
 * Memory may either be allocated at construction (allocate_memory = true)
 * or bound afterwards via bind_allocation().
 */
class VulkanBuffer final {
 public:
  struct BufferProperties final {
    VkDeviceSize size;
    // Offset and range of the bound memory region, exposed via mem_offset()
    // and mem_range() below.
    VkDeviceSize mem_offset;
    VkDeviceSize mem_range;
    VkBufferUsageFlags buffer_usage;
  };

  explicit VulkanBuffer();

  explicit VulkanBuffer(
      const VmaAllocator,
      const VkDeviceSize,
      const VmaAllocationCreateInfo&,
      const VkBufferUsageFlags,
      const bool allocate_memory = true);

  VulkanBuffer(const VulkanBuffer&) = delete;
  VulkanBuffer& operator=(const VulkanBuffer&) = delete;

  VulkanBuffer(VulkanBuffer&&) noexcept;
  VulkanBuffer& operator=(VulkanBuffer&&) noexcept;

  ~VulkanBuffer();

  // Plain aggregate of the handles/ranges needed to describe this buffer to
  // a descriptor.
  struct Package final {
    VkBuffer handle;
    VkDeviceSize buffer_offset;
    VkDeviceSize buffer_range;
  };

  friend struct BufferMemoryBarrier;

 private:
  BufferProperties buffer_properties_;
  VmaAllocator allocator_;
  MemoryAllocation memory_;
  // Indicates whether the underlying memory is owned by this resource
  bool owns_memory_;
  VkBuffer handle_;

 public:
  // Queries VMA for the owning VkDevice. Note: performs a
  // vmaGetAllocatorInfo call on every invocation rather than caching.
  inline VkDevice device() const {
    VmaAllocatorInfo allocator_info{};
    vmaGetAllocatorInfo(allocator_, &allocator_info);
    return allocator_info.device;
  }

  inline VmaAllocator vma_allocator() const {
    return allocator_;
  }

  inline VmaAllocation allocation() const {
    return memory_.allocation;
  }

  // Returns a copy of the parameters the memory allocation was (or will be)
  // created with.
  inline VmaAllocationCreateInfo allocation_create_info() const {
    return VmaAllocationCreateInfo(memory_.create_info);
  }

  inline VkBuffer handle() const {
    return handle_;
  }

  inline VkDeviceSize mem_offset() const {
    return buffer_properties_.mem_offset;
  }

  inline VkDeviceSize mem_range() const {
    return buffer_properties_.mem_range;
  }

  inline VkDeviceSize mem_size() const {
    return buffer_properties_.size;
  }

  // True when device memory has been allocated/bound for this buffer.
  inline bool has_memory() const {
    return (memory_.allocation != VK_NULL_HANDLE);
  }

  inline bool owns_memory() const {
    return owns_memory_;
  }

  // True when the VkBuffer handle itself is valid, regardless of whether
  // memory has been bound yet.
  operator bool() const {
    return (handle_ != VK_NULL_HANDLE);
  }

  // Binds externally-owned memory to this buffer. Only the VmaAllocation
  // handle is recorded; owns_memory_ is left untouched, so the allocation's
  // owner presumably remains responsible for freeing it — confirm against
  // ~VulkanBuffer in the .cpp.
  inline void bind_allocation(const MemoryAllocation& memory) {
    VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!");
    VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_));
    memory_.allocation = memory.allocation;
  }

  VkMemoryRequirements get_memory_requirements() const;
};
170
171 class MemoryMap final {
172 public:
173 explicit MemoryMap(
174 const VulkanBuffer& buffer,
175 const MemoryAccessFlags access);
176
177 MemoryMap(const MemoryMap&) = delete;
178 MemoryMap& operator=(const MemoryMap&) = delete;
179
180 MemoryMap(MemoryMap&&) noexcept;
181 MemoryMap& operator=(MemoryMap&&) = delete;
182
183 ~MemoryMap();
184
185 private:
186 uint8_t access_;
187 VmaAllocator allocator_;
188 VmaAllocation allocation_;
189 void* data_;
190 VkDeviceSize data_len_;
191
192 public:
193 template <typename T>
data()194 T* data() {
195 return reinterpret_cast<T*>(data_);
196 }
197
nbytes()198 inline size_t nbytes() {
199 return utils::safe_downcast<size_t>(data_len_);
200 }
201
202 void invalidate();
203 };
204
// Thin wrapper around a VkBufferMemoryBarrier for the given buffer; the
// out-of-line constructor fills in the access masks and the buffer's
// handle/offset/range (it is a friend of VulkanBuffer).
struct BufferMemoryBarrier final {
  VkBufferMemoryBarrier handle;

  BufferMemoryBarrier(
      const VkAccessFlags src_access_flags,
      const VkAccessFlags dst_access_flags,
      const VulkanBuffer& buffer);
};
213
/*
 * Owns a VkSampler created from the given Properties. Move-constructible but
 * not move-assignable; the friend swap() below substitutes for move
 * assignment where containers need it.
 */
class ImageSampler final {
 public:
  // The sampler creation parameters; also used as the key type in
  // SamplerCache.
  struct Properties final {
    VkFilter filter;
    VkSamplerMipmapMode mipmap_mode;
    VkSamplerAddressMode address_mode;
    VkBorderColor border_color;
  };

  explicit ImageSampler(VkDevice, const Properties&);

  ImageSampler(const ImageSampler&) = delete;
  ImageSampler& operator=(const ImageSampler&) = delete;

  ImageSampler(ImageSampler&&) noexcept;
  ImageSampler& operator=(ImageSampler&&) = delete;

  ~ImageSampler();

 private:
  VkDevice device_;
  VkSampler handle_;

 public:
  VkSampler handle() const {
    return handle_;
  }

  // Allows Properties to serve as the key of an unordered_map.
  struct Hasher {
    size_t operator()(const Properties&) const;
  };

  // We need to define a custom swap function since this class
  // does not allow for move assignment. The swap function will
  // be used in the hash map.
  friend void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept;
};
251
/*
 * Wraps a VkImage, its VkImageView and VkSampler, plus the VMA allocation
 * backing the image. Move-only. Memory may be allocated at construction or
 * bound later via bind_allocation(), which also creates the image view.
 */
class VulkanImage final {
 public:
  struct ImageProperties final {
    VkImageType image_type;
    VkFormat image_format;
    VkExtent3D image_extents;
    VkImageUsageFlags image_usage;
  };

  struct ViewProperties final {
    VkImageViewType view_type;
    VkFormat view_format;
  };

  using SamplerProperties = ImageSampler::Properties;

  // Raw Vulkan handles held by this image.
  struct Handles final {
    VkImage image;
    VkImageView image_view;
    VkSampler sampler;
  };

  explicit VulkanImage();

  explicit VulkanImage(
      const VmaAllocator,
      const VmaAllocationCreateInfo&,
      const ImageProperties&,
      const ViewProperties&,
      const SamplerProperties&,
      const VkImageLayout layout,
      VkSampler,
      const bool allocate_memory = true);

  VulkanImage(const VulkanImage&) = delete;
  VulkanImage& operator=(const VulkanImage&) = delete;

  VulkanImage(VulkanImage&&) noexcept;
  VulkanImage& operator=(VulkanImage&&) noexcept;

  ~VulkanImage();

  // Plain aggregate of the handles + layout needed to describe this image
  // to a descriptor; see package() below.
  struct Package final {
    VkImage handle;
    VkImageLayout image_layout;
    VkImageView image_view;
    VkSampler image_sampler;
  };

  friend struct ImageMemoryBarrier;

 private:
  ImageProperties image_properties_;
  ViewProperties view_properties_;
  SamplerProperties sampler_properties_;
  // The allocator object this was allocated from
  VmaAllocator allocator_;
  // Handles to the allocated memory
  MemoryAllocation memory_;
  // Indicates whether the underlying memory is owned by this resource
  bool owns_memory_;
  Handles handles_;
  // Layout
  VkImageLayout layout_;

 public:
  // Creates the VkImageView for this image (defined out-of-line); called by
  // bind_allocation() once the image is bound to memory.
  void create_image_view();

  // Queries VMA for the owning VkDevice. Note: performs a
  // vmaGetAllocatorInfo call on every invocation rather than caching.
  inline VkDevice device() const {
    VmaAllocatorInfo allocator_info{};
    vmaGetAllocatorInfo(allocator_, &allocator_info);
    return allocator_info.device;
  }

  inline VmaAllocator vma_allocator() const {
    return allocator_;
  }

  inline VmaAllocation allocation() const {
    return memory_.allocation;
  }

  // Returns a copy of the parameters the memory allocation was (or will be)
  // created with.
  inline VmaAllocationCreateInfo allocation_create_info() const {
    return VmaAllocationCreateInfo(memory_.create_info);
  }

  inline VkFormat format() const {
    return image_properties_.image_format;
  }

  inline VkExtent3D extents() const {
    return image_properties_.image_extents;
  }

  inline VkImage handle() const {
    return handles_.image;
  }

  inline VkImageView image_view() const {
    return handles_.image_view;
  }

  inline VkSampler sampler() const {
    return handles_.sampler;
  }

  // Bundles the raw handles and the currently-recorded layout.
  Package package() const {
    return {
        handles_.image,
        layout_,
        handles_.image_view,
        handles_.sampler,
    };
  }

  inline VkImageLayout layout() const {
    return layout_;
  }

  // Records the layout the image is believed to be in; this does not itself
  // perform a layout transition (see ImageMemoryBarrier for that).
  inline void set_layout(const VkImageLayout layout) {
    layout_ = layout;
  }

  // True when device memory has been allocated/bound for this image.
  inline bool has_memory() const {
    return (memory_.allocation != VK_NULL_HANDLE);
  }

  inline bool owns_memory() const {
    return owns_memory_;
  }

  // True when the VkImage handle itself is valid, regardless of whether
  // memory has been bound yet.
  inline operator bool() const {
    return (handles_.image != VK_NULL_HANDLE);
  }

  // Binds externally-owned memory to this image, then creates the image
  // view. Only the VmaAllocation handle is recorded; owns_memory_ is left
  // untouched, so the allocation's owner presumably remains responsible for
  // freeing it — confirm against ~VulkanImage in the .cpp.
  inline void bind_allocation(const MemoryAllocation& memory) {
    VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!");
    VK_CHECK(vmaBindImageMemory(allocator_, memory.allocation, handles_.image));
    memory_.allocation = memory.allocation;

    // Only create the image view if the image has been bound to memory
    create_image_view();
  }

  VkMemoryRequirements get_memory_requirements() const;
};
398
// Thin wrapper around a VkImageMemoryBarrier for the given image; the
// out-of-line constructor fills in the access masks, the old/new layouts,
// and the image handle (it is a friend of VulkanImage).
struct ImageMemoryBarrier final {
  VkImageMemoryBarrier handle;

  ImageMemoryBarrier(
      const VkAccessFlags src_access_flags,
      const VkAccessFlags dst_access_flags,
      const VkImageLayout src_layout_flags,
      const VkImageLayout dst_layout_flags,
      const VulkanImage& image);
};
409
/*
 * Cache of ImageSampler objects keyed by their creation Properties, so that
 * samplers with identical settings are shared. Move-constructible only.
 */
class SamplerCache final {
 public:
  explicit SamplerCache(VkDevice device);

  SamplerCache(const SamplerCache&) = delete;
  SamplerCache& operator=(const SamplerCache&) = delete;

  SamplerCache(SamplerCache&&) noexcept;
  SamplerCache& operator=(SamplerCache&&) = delete;

  ~SamplerCache();

  using Key = ImageSampler::Properties;
  using Value = ImageSampler;
  using Hasher = ImageSampler::Hasher;

 private:
  // Multiple threads could potentially be adding entries into the cache, so use
  // a mutex to manage access
  std::mutex cache_mutex_;

  VkDevice device_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  // Returns the sampler for the given properties — presumably creating and
  // caching it on first use; confirm against the .cpp.
  VkSampler retrieve(const Key&);
  // Clears all cached samplers.
  void purge();
};
438
/*
 * Owns a VmaAllocator and acts as the factory for the resource types above
 * (MemoryAllocation, VulkanBuffer, VulkanImage). Move-constructible only.
 */
class MemoryAllocator final {
 public:
  explicit MemoryAllocator(
      VkInstance instance,
      VkPhysicalDevice physical_device,
      VkDevice device);

  MemoryAllocator(const MemoryAllocator&) = delete;
  MemoryAllocator& operator=(const MemoryAllocator&) = delete;

  MemoryAllocator(MemoryAllocator&&) noexcept;
  MemoryAllocator& operator=(MemoryAllocator&&) = delete;

  ~MemoryAllocator();

 private:
  VkInstance instance_;
  VkPhysicalDevice physical_device_;
  VkDevice device_;
  VmaAllocator allocator_;

 public:
  // Allocates device memory matching the given requirements, without
  // creating a buffer or image for it (used with bind_allocation()).
  MemoryAllocation create_allocation(
      const VkMemoryRequirements& memory_requirements,
      const VmaAllocationCreateInfo& create_info);

  // Creates a VulkanImage. If allocate_memory is false, the image must later
  // be bound via VulkanImage::bind_allocation(). NOTE(review): the effect of
  // allow_transfer is defined in the .cpp — presumably it adds transfer
  // usage bits; confirm there.
  VulkanImage create_image(
      const VkExtent3D&,
      const VkFormat,
      const VkImageType,
      const VkImageViewType,
      const VulkanImage::SamplerProperties&,
      VkSampler,
      const bool allow_transfer = false,
      const bool allocate_memory = true);

  // Creates a storage buffer of the given size; memory residency behavior
  // for gpu_only is defined in the .cpp.
  VulkanBuffer create_storage_buffer(
      const VkDeviceSize,
      const bool gpu_only = true,
      const bool allocate_memory = true);

  // Creates a buffer intended for staging host<->device transfers.
  VulkanBuffer create_staging_buffer(const VkDeviceSize);

  /*
   * Create a uniform buffer with a specified size
   */
  VulkanBuffer create_uniform_buffer(const VkDeviceSize);

  /*
   * Create a uniform buffer containing the data in an arbitrary struct
   */
  template <typename Block>
  VulkanBuffer create_params_buffer(const Block& block);

  // Collects allocator-wide memory usage statistics from VMA; printable via
  // the global operator<< declared at the top of this header.
  VmaTotalStatistics get_memory_statistics() const {
    VmaTotalStatistics stats = {};
    vmaCalculateStatistics(allocator_, &stats);
    return stats;
  }
};
499
/*
 * Owns a VkFence used to synchronize the CPU with queue submissions.
 * Move-only. A default-constructed instance holds no fence.
 */
class VulkanFence final {
 public:
  // TODO: This is required for the lazy allocation pattern in api/Tensor.
  // It will be disabled pending future refactors.
  explicit VulkanFence();

  explicit VulkanFence(VkDevice);

  VulkanFence(const VulkanFence&) = delete;
  VulkanFence& operator=(const VulkanFence&) = delete;

  VulkanFence(VulkanFence&&) noexcept;
  VulkanFence& operator=(VulkanFence&&) noexcept;

  ~VulkanFence();

 private:
  VkDevice device_;
  VkFence handle_;
  // Set once the fence has been handed to a queue submission; see
  // get_submit_handle().
  bool waiting_;

 public:
  // Used to get the handle for a queue submission. Also flags that a signal
  // is now pending, which waiting() reports.
  VkFence get_submit_handle() {
    if (handle_ != VK_NULL_HANDLE) {
      // Indicate we are now waiting for this fence to be signaled
      waiting_ = true;
    }
    return handle_;
  }

  VkFence handle() {
    return handle_;
  }

  // Trigger a synchronous wait for the fence to be signaled
  void wait();

  // True after get_submit_handle() has handed the fence to a submission;
  // cleared by the implementation (see the .cpp).
  bool waiting() const {
    return waiting_;
  }

  // True when a valid VkFence handle is held.
  operator bool() const {
    return (VK_NULL_HANDLE != handle_);
  }
};
546
547 // A pool to track created Fences and reuse ones that are available.
548 // Only intended to be modified by one thread at a time.
549 struct FencePool final {
550 VkDevice device_;
551
552 std::stack<VulkanFence> pool_;
553
FencePoolfinal554 explicit FencePool(VkDevice device) : device_(device), pool_{} {}
555
556 // Returns an rvalue reference to a fence, so that it can be moved
get_fencefinal557 inline VulkanFence get_fence() {
558 if (pool_.empty()) {
559 VulkanFence new_fence = VulkanFence(device_);
560 return new_fence;
561 }
562
563 VulkanFence top_fence = std::move(pool_.top());
564 pool_.pop();
565
566 return top_fence;
567 }
568
569 // Marks the fence as available
return_fencefinal570 inline void return_fence(VulkanFence& fence) {
571 pool_.push(std::move(fence));
572 }
573 };
574
575 //
576 // Impl
577 //
578
579 template <typename Block>
create_params_buffer(const Block & block)580 inline VulkanBuffer MemoryAllocator::create_params_buffer(const Block& block) {
581 VulkanBuffer uniform_buffer = create_uniform_buffer(sizeof(Block));
582
583 // Fill the uniform buffer with data in block
584 {
585 MemoryMap mapping(uniform_buffer, MemoryAccessType::WRITE);
586 Block* data_ptr = mapping.template data<Block>();
587
588 *data_ptr = block;
589 }
590
591 return uniform_buffer;
592 }
593
594 } // namespace api
595 } // namespace vulkan
596 } // namespace native
597 } // namespace at
598
599 #endif /* USE_VULKAN_API */
600