// external/pytorch/aten/src/ATen/native/vulkan/api/Resource.h
#pragma once

// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/vk_api.h>

#include <ATen/native/vulkan/api/Allocator.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/Utils.h>

#include <mutex>
#include <ostream>
#include <stack>
#include <unordered_map>

std::ostream& operator<<(std::ostream& out, VmaTotalStatistics stats);

namespace at {
namespace native {
namespace vulkan {
namespace api {

using MemoryAccessFlags = uint8_t;

constexpr VmaAllocationCreateFlags DEFAULT_ALLOCATION_STRATEGY =
    VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT;

enum MemoryAccessType : MemoryAccessFlags {
  NONE = 0u << 0u,
  READ = 1u << 0u,
  WRITE = 1u << 1u,
};
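
// The access types are bit flags over MemoryAccessFlags and may be combined.
// Illustrative sketch only; `buffer` is assumed to be a valid VulkanBuffer
// (declared further below) that is mapped for both reading and writing:
//
//   MemoryAccessFlags read_write =
//       MemoryAccessType::READ | MemoryAccessType::WRITE;
//   MemoryMap mapping(buffer, read_write);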

struct MemoryBarrier final {
  VkMemoryBarrier handle;

  MemoryBarrier(
      const VkAccessFlags src_access_flags,
      const VkAccessFlags dst_access_flags);
};

struct MemoryAllocation final {
  explicit MemoryAllocation();

  explicit MemoryAllocation(
      const VmaAllocator,
      const VkMemoryRequirements&,
      const VmaAllocationCreateInfo&);

  MemoryAllocation(const MemoryAllocation&) = delete;
  MemoryAllocation& operator=(const MemoryAllocation&) = delete;

  MemoryAllocation(MemoryAllocation&&) noexcept;
  MemoryAllocation& operator=(MemoryAllocation&&) noexcept;

  ~MemoryAllocation();

  VkMemoryRequirements memory_requirements;
  // The properties this allocation was created with
  VmaAllocationCreateInfo create_info;
  // The allocator object this was allocated from
  VmaAllocator allocator;
  // Handles to the allocated memory
  VmaAllocation allocation;

  operator bool() const {
    return (allocation != VK_NULL_HANDLE);
  }
};

class VulkanBuffer final {
 public:
  struct BufferProperties final {
    VkDeviceSize size;
    VkDeviceSize mem_offset;
    VkDeviceSize mem_range;
    VkBufferUsageFlags buffer_usage;
  };

  explicit VulkanBuffer();

  explicit VulkanBuffer(
      const VmaAllocator,
      const VkDeviceSize,
      const VmaAllocationCreateInfo&,
      const VkBufferUsageFlags,
      const bool allocate_memory = true);

  VulkanBuffer(const VulkanBuffer&) = delete;
  VulkanBuffer& operator=(const VulkanBuffer&) = delete;

  VulkanBuffer(VulkanBuffer&&) noexcept;
  VulkanBuffer& operator=(VulkanBuffer&&) noexcept;

  ~VulkanBuffer();

  struct Package final {
    VkBuffer handle;
    VkDeviceSize buffer_offset;
    VkDeviceSize buffer_range;
  };

  friend struct BufferMemoryBarrier;

 private:
  BufferProperties buffer_properties_;
  VmaAllocator allocator_;
  MemoryAllocation memory_;
  // Indicates whether the underlying memory is owned by this resource
  bool owns_memory_;
  VkBuffer handle_;

 public:
  inline VkDevice device() const {
    VmaAllocatorInfo allocator_info{};
    vmaGetAllocatorInfo(allocator_, &allocator_info);
    return allocator_info.device;
  }

  inline VmaAllocator vma_allocator() const {
    return allocator_;
  }

  inline VmaAllocation allocation() const {
    return memory_.allocation;
  }

  inline VmaAllocationCreateInfo allocation_create_info() const {
    return VmaAllocationCreateInfo(memory_.create_info);
  }

  inline VkBuffer handle() const {
    return handle_;
  }

  inline VkDeviceSize mem_offset() const {
    return buffer_properties_.mem_offset;
  }

  inline VkDeviceSize mem_range() const {
    return buffer_properties_.mem_range;
  }

  inline VkDeviceSize mem_size() const {
    return buffer_properties_.size;
  }

  inline bool has_memory() const {
    return (memory_.allocation != VK_NULL_HANDLE);
  }

  inline bool owns_memory() const {
    return owns_memory_;
  }

  operator bool() const {
    return (handle_ != VK_NULL_HANDLE);
  }

  inline void bind_allocation(const MemoryAllocation& memory) {
    VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!");
    VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_));
    memory_.allocation = memory.allocation;
  }

  VkMemoryRequirements get_memory_requirements() const;
};

class MemoryMap final {
 public:
  explicit MemoryMap(
      const VulkanBuffer& buffer,
      const MemoryAccessFlags access);

  MemoryMap(const MemoryMap&) = delete;
  MemoryMap& operator=(const MemoryMap&) = delete;

  MemoryMap(MemoryMap&&) noexcept;
  MemoryMap& operator=(MemoryMap&&) = delete;

  ~MemoryMap();

 private:
  uint8_t access_;
  VmaAllocator allocator_;
  VmaAllocation allocation_;
  void* data_;
  VkDeviceSize data_len_;

 public:
  template <typename T>
  T* data() {
    return reinterpret_cast<T*>(data_);
  }

  inline size_t nbytes() {
    return utils::safe_downcast<size_t>(data_len_);
  }

  void invalidate();
};
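
// Illustrative sketch of reading back data the GPU has written to a
// host-visible buffer. This assumes `staging` is a valid VulkanBuffer whose
// contents have already been synchronized (e.g. via a fence), and that
// invalidate() is what makes device writes visible to the host before the
// read; the mapping is released when `mapping` goes out of scope:
//
//   {
//     MemoryMap mapping(staging, MemoryAccessType::READ);
//     mapping.invalidate();
//     const float* src = mapping.data<float>();
//     std::vector<float> out(src, src + mapping.nbytes() / sizeof(float));
//   }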

struct BufferMemoryBarrier final {
  VkBufferMemoryBarrier handle;

  BufferMemoryBarrier(
      const VkAccessFlags src_access_flags,
      const VkAccessFlags dst_access_flags,
      const VulkanBuffer& buffer);
};

class ImageSampler final {
 public:
  struct Properties final {
    VkFilter filter;
    VkSamplerMipmapMode mipmap_mode;
    VkSamplerAddressMode address_mode;
    VkBorderColor border_color;
  };

  explicit ImageSampler(VkDevice, const Properties&);

  ImageSampler(const ImageSampler&) = delete;
  ImageSampler& operator=(const ImageSampler&) = delete;

  ImageSampler(ImageSampler&&) noexcept;
  ImageSampler& operator=(ImageSampler&&) = delete;

  ~ImageSampler();

 private:
  VkDevice device_;
  VkSampler handle_;

 public:
  VkSampler handle() const {
    return handle_;
  }

  struct Hasher {
    size_t operator()(const Properties&) const;
  };

  // We need to define a custom swap function since this class
  // does not allow for move assignment. The swap function will
  // be used in the hash map.
  friend void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept;
};

class VulkanImage final {
 public:
  struct ImageProperties final {
    VkImageType image_type;
    VkFormat image_format;
    VkExtent3D image_extents;
    VkImageUsageFlags image_usage;
  };

  struct ViewProperties final {
    VkImageViewType view_type;
    VkFormat view_format;
  };

  using SamplerProperties = ImageSampler::Properties;

  struct Handles final {
    VkImage image;
    VkImageView image_view;
    VkSampler sampler;
  };

  explicit VulkanImage();

  explicit VulkanImage(
      const VmaAllocator,
      const VmaAllocationCreateInfo&,
      const ImageProperties&,
      const ViewProperties&,
      const SamplerProperties&,
      const VkImageLayout layout,
      VkSampler,
      const bool allocate_memory = true);

  VulkanImage(const VulkanImage&) = delete;
  VulkanImage& operator=(const VulkanImage&) = delete;

  VulkanImage(VulkanImage&&) noexcept;
  VulkanImage& operator=(VulkanImage&&) noexcept;

  ~VulkanImage();

  struct Package final {
    VkImage handle;
    VkImageLayout image_layout;
    VkImageView image_view;
    VkSampler image_sampler;
  };

  friend struct ImageMemoryBarrier;

 private:
  ImageProperties image_properties_;
  ViewProperties view_properties_;
  SamplerProperties sampler_properties_;
  // The allocator object this was allocated from
  VmaAllocator allocator_;
  // Handles to the allocated memory
  MemoryAllocation memory_;
  // Indicates whether the underlying memory is owned by this resource
  bool owns_memory_;
  Handles handles_;
  // Layout
  VkImageLayout layout_;

 public:
  void create_image_view();

  inline VkDevice device() const {
    VmaAllocatorInfo allocator_info{};
    vmaGetAllocatorInfo(allocator_, &allocator_info);
    return allocator_info.device;
  }

  inline VmaAllocator vma_allocator() const {
    return allocator_;
  }

  inline VmaAllocation allocation() const {
    return memory_.allocation;
  }

  inline VmaAllocationCreateInfo allocation_create_info() const {
    return VmaAllocationCreateInfo(memory_.create_info);
  }

  inline VkFormat format() const {
    return image_properties_.image_format;
  }

  inline VkExtent3D extents() const {
    return image_properties_.image_extents;
  }

  inline VkImage handle() const {
    return handles_.image;
  }

  inline VkImageView image_view() const {
    return handles_.image_view;
  }

  inline VkSampler sampler() const {
    return handles_.sampler;
  }

  Package package() const {
    return {
        handles_.image,
        layout_,
        handles_.image_view,
        handles_.sampler,
    };
  }

  inline VkImageLayout layout() const {
    return layout_;
  }

  inline void set_layout(const VkImageLayout layout) {
    layout_ = layout;
  }

  inline bool has_memory() const {
    return (memory_.allocation != VK_NULL_HANDLE);
  }

  inline bool owns_memory() const {
    return owns_memory_;
  }

  inline operator bool() const {
    return (handles_.image != VK_NULL_HANDLE);
  }

  inline void bind_allocation(const MemoryAllocation& memory) {
    VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!");
    VK_CHECK(vmaBindImageMemory(allocator_, memory.allocation, handles_.image));
    memory_.allocation = memory.allocation;

    // Only create the image view if the image has been bound to memory
    create_image_view();
  }

  VkMemoryRequirements get_memory_requirements() const;
};

struct ImageMemoryBarrier final {
  VkImageMemoryBarrier handle;

  ImageMemoryBarrier(
      const VkAccessFlags src_access_flags,
      const VkAccessFlags dst_access_flags,
      const VkImageLayout src_layout_flags,
      const VkImageLayout dst_layout_flags,
      const VulkanImage& image);
};
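
// Illustrative sketch of constructing a barrier that transitions an image
// from a compute-write layout to a shader-read layout; `image` is assumed to
// be a valid VulkanImage, and the barrier must still be recorded into a
// command buffer elsewhere:
//
//   ImageMemoryBarrier barrier(
//       VK_ACCESS_SHADER_WRITE_BIT,
//       VK_ACCESS_SHADER_READ_BIT,
//       VK_IMAGE_LAYOUT_GENERAL,
//       VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
//       image);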

class SamplerCache final {
 public:
  explicit SamplerCache(VkDevice device);

  SamplerCache(const SamplerCache&) = delete;
  SamplerCache& operator=(const SamplerCache&) = delete;

  SamplerCache(SamplerCache&&) noexcept;
  SamplerCache& operator=(SamplerCache&&) = delete;

  ~SamplerCache();

  using Key = ImageSampler::Properties;
  using Value = ImageSampler;
  using Hasher = ImageSampler::Hasher;

 private:
  // Multiple threads could potentially be adding entries into the cache, so use
  // a mutex to manage access
  std::mutex cache_mutex_;

  VkDevice device_;
  std::unordered_map<Key, Value, Hasher> cache_;

 public:
  VkSampler retrieve(const Key&);
  void purge();
};
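
// Illustrative usage sketch (assumes a valid VkDevice `device`); samplers
// with identical properties are created once and shared through the cache:
//
//   SamplerCache sampler_cache(device);
//   VkSampler sampler = sampler_cache.retrieve({
//       VK_FILTER_NEAREST,
//       VK_SAMPLER_MIPMAP_MODE_NEAREST,
//       VK_SAMPLER_ADDRESS_MODE_REPEAT,
//       VK_BORDER_COLOR_FLOAT_TRANSPARENT_BLACK,
//   });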

class MemoryAllocator final {
 public:
  explicit MemoryAllocator(
      VkInstance instance,
      VkPhysicalDevice physical_device,
      VkDevice device);

  MemoryAllocator(const MemoryAllocator&) = delete;
  MemoryAllocator& operator=(const MemoryAllocator&) = delete;

  MemoryAllocator(MemoryAllocator&&) noexcept;
  MemoryAllocator& operator=(MemoryAllocator&&) = delete;

  ~MemoryAllocator();

 private:
  VkInstance instance_;
  VkPhysicalDevice physical_device_;
  VkDevice device_;
  VmaAllocator allocator_;

 public:
  MemoryAllocation create_allocation(
      const VkMemoryRequirements& memory_requirements,
      const VmaAllocationCreateInfo& create_info);

  VulkanImage create_image(
      const VkExtent3D&,
      const VkFormat,
      const VkImageType,
      const VkImageViewType,
      const VulkanImage::SamplerProperties&,
      VkSampler,
      const bool allow_transfer = false,
      const bool allocate_memory = true);

  VulkanBuffer create_storage_buffer(
      const VkDeviceSize,
      const bool gpu_only = true,
      const bool allocate_memory = true);

  VulkanBuffer create_staging_buffer(const VkDeviceSize);

  /*
   * Create a uniform buffer with a specified size
   */
  VulkanBuffer create_uniform_buffer(const VkDeviceSize);

  /*
   * Create a uniform buffer containing the data in an arbitrary struct
   */
  template <typename Block>
  VulkanBuffer create_params_buffer(const Block& block);

  VmaTotalStatistics get_memory_statistics() const {
    VmaTotalStatistics stats = {};
    vmaCalculateStatistics(allocator_, &stats);
    return stats;
  }
};
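
// The allocate_memory flags above enable a deferred-allocation pattern: a
// resource can be created without backing memory and bound later. A minimal
// sketch, assuming a valid MemoryAllocator `allocator` and a VkDeviceSize
// `size`; note that bind_allocation() does not transfer ownership, so the
// MemoryAllocation must outlive the buffer's use of it:
//
//   VulkanBuffer buffer = allocator.create_storage_buffer(
//       size, /*gpu_only = */ true, /*allocate_memory = */ false);
//   MemoryAllocation memory = allocator.create_allocation(
//       buffer.get_memory_requirements(), buffer.allocation_create_info());
//   buffer.bind_allocation(memory);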

class VulkanFence final {
 public:
  // TODO: This is required for the lazy allocation pattern in api/Tensor.
  //       It will be disabled pending future refactors.
  explicit VulkanFence();

  explicit VulkanFence(VkDevice);

  VulkanFence(const VulkanFence&) = delete;
  VulkanFence& operator=(const VulkanFence&) = delete;

  VulkanFence(VulkanFence&&) noexcept;
  VulkanFence& operator=(VulkanFence&&) noexcept;

  ~VulkanFence();

 private:
  VkDevice device_;
  VkFence handle_;
  bool waiting_;

 public:
  // Used to get the handle for a queue submission.
  VkFence get_submit_handle() {
    if (handle_ != VK_NULL_HANDLE) {
      // Indicate we are now waiting for this fence to be signaled
      waiting_ = true;
    }
    return handle_;
  }

  VkFence handle() {
    return handle_;
  }

  // Trigger a synchronous wait for the fence to be signaled
  void wait();

  bool waiting() const {
    return waiting_;
  }

  operator bool() const {
    return (VK_NULL_HANDLE != handle_);
  }
};
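
// Typical usage sketch: hand the fence to a queue submission, then block on
// it. Assumes a valid VkDevice `device`, VkQueue `queue`, and VkSubmitInfo
// `submit_info` set up elsewhere:
//
//   VulkanFence fence(device);
//   VK_CHECK(
//       vkQueueSubmit(queue, 1u, &submit_info, fence.get_submit_handle()));
//   fence.wait();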

// A pool to track created Fences and reuse ones that are available.
// Only intended to be modified by one thread at a time.
struct FencePool final {
  VkDevice device_;

  std::stack<VulkanFence> pool_;

  explicit FencePool(VkDevice device) : device_(device), pool_{} {}

  // Returns a fence by value (so it can be moved into the caller), reusing a
  // pooled fence when one is available.
  inline VulkanFence get_fence() {
    if (pool_.empty()) {
      VulkanFence new_fence = VulkanFence(device_);
      return new_fence;
    }

    VulkanFence top_fence = std::move(pool_.top());
    pool_.pop();

    return top_fence;
  }

  // Marks the fence as available
  inline void return_fence(VulkanFence& fence) {
    pool_.push(std::move(fence));
  }
};
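
// Illustrative sketch of cycling fences through the pool (assumes a valid
// FencePool `fence_pool` and a submission that signals the fence):
//
//   VulkanFence fence = fence_pool.get_fence();
//   // ... submit work with fence.get_submit_handle(), then fence.wait() ...
//   fence_pool.return_fence(fence);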

//
// Impl
//

template <typename Block>
inline VulkanBuffer MemoryAllocator::create_params_buffer(const Block& block) {
  VulkanBuffer uniform_buffer = create_uniform_buffer(sizeof(Block));

  // Fill the uniform buffer with data in block
  {
    MemoryMap mapping(uniform_buffer, MemoryAccessType::WRITE);
    Block* data_ptr = mapping.template data<Block>();

    *data_ptr = block;
  }

  return uniform_buffer;
}
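
// Usage sketch: pack shader parameters into a struct and upload them as a
// uniform buffer. The Params struct below is hypothetical and must match the
// corresponding shader's uniform block layout; `allocator` is assumed to be a
// valid MemoryAllocator:
//
//   struct Params final {
//     int32_t width;
//     int32_t height;
//   };
//   const Params params{64, 64};
//   VulkanBuffer params_buffer = allocator.create_params_buffer(params);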

} // namespace api
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */