xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/api/Adapter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/native/vulkan/api/Adapter.h>

#include <algorithm>
#include <bitset>
#include <cstring>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <utility>
8 
9 namespace at {
10 namespace native {
11 namespace vulkan {
12 namespace api {
13 
PhysicalDevice(VkPhysicalDevice physical_device_handle)14 PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
15     : handle(physical_device_handle),
16       properties{},
17       memory_properties{},
18       queue_families{},
19       num_compute_queues(0),
20       has_unified_memory(false),
21       has_timestamps(properties.limits.timestampComputeAndGraphics),
22       timestamp_period(properties.limits.timestampPeriod) {
23   // Extract physical device properties
24   vkGetPhysicalDeviceProperties(handle, &properties);
25   vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties);
26 
27   // Check if there are any memory types have both the HOST_VISIBLE and the
28   // DEVICE_LOCAL property flags
29   const VkMemoryPropertyFlags unified_memory_flags =
30       VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
31   for (size_t i = 0; i < memory_properties.memoryTypeCount; ++i) {
32     if (memory_properties.memoryTypes[i].propertyFlags | unified_memory_flags) {
33       has_unified_memory = true;
34       break;
35     }
36   }
37 
38   uint32_t queue_family_count = 0;
39   vkGetPhysicalDeviceQueueFamilyProperties(
40       handle, &queue_family_count, nullptr);
41 
42   queue_families.resize(queue_family_count);
43   vkGetPhysicalDeviceQueueFamilyProperties(
44       handle, &queue_family_count, queue_families.data());
45 
46   // Find the total number of compute queues
47   for (const VkQueueFamilyProperties& p : queue_families) {
48     // Check if this family has compute capability
49     if (p.queueFlags & VK_QUEUE_COMPUTE_BIT) {
50       num_compute_queues += p.queueCount;
51     }
52   }
53 }
54 
55 namespace {
56 
find_requested_device_extensions(VkPhysicalDevice physical_device,std::vector<const char * > & enabled_extensions,const std::vector<const char * > & requested_extensions)57 void find_requested_device_extensions(
58     VkPhysicalDevice physical_device,
59     std::vector<const char*>& enabled_extensions,
60     const std::vector<const char*>& requested_extensions) {
61   uint32_t device_extension_properties_count = 0;
62   VK_CHECK(vkEnumerateDeviceExtensionProperties(
63       physical_device, nullptr, &device_extension_properties_count, nullptr));
64   std::vector<VkExtensionProperties> device_extension_properties(
65       device_extension_properties_count);
66   VK_CHECK(vkEnumerateDeviceExtensionProperties(
67       physical_device,
68       nullptr,
69       &device_extension_properties_count,
70       device_extension_properties.data()));
71 
72   std::vector<const char*> enabled_device_extensions;
73 
74   for (const auto& requested_extension : requested_extensions) {
75     for (const auto& extension : device_extension_properties) {
76       if (strcmp(requested_extension, extension.extensionName) == 0) {
77         enabled_extensions.push_back(requested_extension);
78         break;
79       }
80     }
81   }
82 }
83 
create_logical_device(const PhysicalDevice & physical_device,const uint32_t num_queues_to_create,std::vector<Adapter::Queue> & queues,std::vector<uint32_t> & queue_usage)84 VkDevice create_logical_device(
85     const PhysicalDevice& physical_device,
86     const uint32_t num_queues_to_create,
87     std::vector<Adapter::Queue>& queues,
88     std::vector<uint32_t>& queue_usage) {
89   // Find compute queues up to the requested number of queues
90 
91   std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
92   queue_create_infos.reserve(num_queues_to_create);
93 
94   std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
95   queues_to_get.reserve(num_queues_to_create);
96 
97   uint32_t remaining_queues = num_queues_to_create;
98   for (uint32_t family_i = 0; family_i < physical_device.queue_families.size();
99        ++family_i) {
100     const VkQueueFamilyProperties& queue_properties =
101         physical_device.queue_families.at(family_i);
102     // Check if this family has compute capability
103     if (queue_properties.queueFlags & VK_QUEUE_COMPUTE_BIT) {
104       const uint32_t queues_to_init =
105           std::min(remaining_queues, queue_properties.queueCount);
106 
107       const std::vector<float> queue_priorities(queues_to_init, 1.0f);
108       queue_create_infos.push_back({
109           VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO, // sType
110           nullptr, // pNext
111           0u, // flags
112           family_i, // queueFamilyIndex
113           queues_to_init, // queueCount
114           queue_priorities.data(), // pQueuePriorities
115       });
116 
117       for (size_t queue_i = 0; queue_i < queues_to_init; ++queue_i) {
118         // Use this to get the queue handle once device is created
119         queues_to_get.emplace_back(family_i, queue_i);
120       }
121       remaining_queues -= queues_to_init;
122     }
123     if (remaining_queues == 0) {
124       break;
125     }
126   }
127 
128   queues.reserve(queues_to_get.size());
129   queue_usage.reserve(queues_to_get.size());
130 
131   // Create the VkDevice
132 
133   std::vector<const char*> requested_device_extensions{
134 #ifdef VK_KHR_portability_subset
135       VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
136 #endif /* VK_KHR_portability_subset */
137   };
138 
139   std::vector<const char*> enabled_device_extensions;
140   find_requested_device_extensions(
141       physical_device.handle,
142       enabled_device_extensions,
143       requested_device_extensions);
144 
145   const VkDeviceCreateInfo device_create_info{
146       VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
147       nullptr, // pNext
148       0u, // flags
149       static_cast<uint32_t>(queue_create_infos.size()), // queueCreateInfoCount
150       queue_create_infos.data(), // pQueueCreateInfos
151       0u, // enabledLayerCount
152       nullptr, // ppEnabledLayerNames
153       static_cast<uint32_t>(
154           enabled_device_extensions.size()), // enabledExtensionCount
155       enabled_device_extensions.data(), // ppEnabledExtensionNames
156       nullptr, // pEnabledFeatures
157   };
158 
159   VkDevice handle = nullptr;
160   VK_CHECK(vkCreateDevice(
161       physical_device.handle, &device_create_info, nullptr, &handle));
162 
163 #ifdef USE_VULKAN_VOLK
164   volkLoadDevice(handle);
165 #endif /* USE_VULKAN_VOLK */
166 
167   // Obtain handles for the created queues and initialize queue usage heuristic
168 
169   for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
170     VkQueue queue_handle = VK_NULL_HANDLE;
171     VkQueueFlags flags =
172         physical_device.queue_families.at(queue_idx.first).queueFlags;
173     vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle);
174     queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
175     // Initial usage value
176     queue_usage.push_back(0);
177   }
178 
179   return handle;
180 }
181 
182 // Print utils
183 
get_device_type_str(const VkPhysicalDeviceType type)184 std::string get_device_type_str(const VkPhysicalDeviceType type) {
185   switch (type) {
186     case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
187       return "INTEGRATED_GPU";
188     case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
189       return "DISCRETE_GPU";
190     case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
191       return "VIRTUAL_GPU";
192     case VK_PHYSICAL_DEVICE_TYPE_CPU:
193       return "CPU";
194     default:
195       return "UNKNOWN";
196   }
197 }
198 
get_memory_properties_str(const VkMemoryPropertyFlags flags)199 std::string get_memory_properties_str(const VkMemoryPropertyFlags flags) {
200   std::bitset<10> values(flags);
201   std::stringstream ss("|");
202   if (values[0]) {
203     ss << " DEVICE_LOCAL |";
204   }
205   if (values[1]) {
206     ss << " HOST_VISIBLE |";
207   }
208   if (values[2]) {
209     ss << " HOST_COHERENT |";
210   }
211   if (values[3]) {
212     ss << " HOST_CACHED |";
213   }
214   if (values[4]) {
215     ss << " LAZILY_ALLOCATED |";
216   }
217 
218   return ss.str();
219 }
220 
get_queue_family_properties_str(const VkQueueFlags flags)221 std::string get_queue_family_properties_str(const VkQueueFlags flags) {
222   std::bitset<10> values(flags);
223   std::stringstream ss("|");
224   if (values[0]) {
225     ss << " GRAPHICS |";
226   }
227   if (values[1]) {
228     ss << " COMPUTE |";
229   }
230   if (values[2]) {
231     ss << " TRANSFER |";
232   }
233 
234   return ss.str();
235 }
236 
237 } // namespace
238 
239 //
240 // DeviceHandle
241 //
242 
DeviceHandle(VkDevice device)243 DeviceHandle::DeviceHandle(VkDevice device) : handle_(device) {}
244 
DeviceHandle(DeviceHandle && other)245 DeviceHandle::DeviceHandle(DeviceHandle&& other) noexcept
246     : handle_(other.handle_) {
247   other.handle_ = VK_NULL_HANDLE;
248 }
249 
~DeviceHandle()250 DeviceHandle::~DeviceHandle() {
251   if (VK_NULL_HANDLE == handle_) {
252     return;
253   }
254   vkDestroyDevice(handle_, nullptr);
255 }
256 
257 //
258 // Adapter
259 //
260 
Adapter(VkInstance instance,PhysicalDevice physical_device,const uint32_t num_queues)261 Adapter::Adapter(
262     VkInstance instance,
263     PhysicalDevice physical_device,
264     const uint32_t num_queues)
265     : queue_usage_mutex_{},
266       physical_device_(std::move(physical_device)),
267       queues_{},
268       queue_usage_{},
269       queue_mutexes_{},
270       instance_(instance),
271       device_(create_logical_device(
272           physical_device_,
273           num_queues,
274           queues_,
275           queue_usage_)),
276       shader_layout_cache_(device_.handle_),
277       shader_cache_(device_.handle_),
278       pipeline_layout_cache_(device_.handle_),
279       compute_pipeline_cache_(device_.handle_),
280       sampler_cache_(device_.handle_),
281       vma_(instance_, physical_device_.handle, device_.handle_) {}
282 
request_queue()283 Adapter::Queue Adapter::request_queue() {
284   // Lock the mutex as multiple threads can request a queue at the same time
285   std::lock_guard<std::mutex> lock(queue_usage_mutex_);
286 
287   uint32_t min_usage = UINT32_MAX;
288   uint32_t min_used_i = 0;
289   for (size_t i = 0; i < queues_.size(); ++i) {
290     if (queue_usage_[i] < min_usage) {
291       min_used_i = i;
292       min_usage = queue_usage_[i];
293     }
294   }
295   queue_usage_[min_used_i] += 1;
296 
297   return queues_[min_used_i];
298 }
299 
return_queue(Adapter::Queue & compute_queue)300 void Adapter::return_queue(Adapter::Queue& compute_queue) {
301   for (size_t i = 0; i < queues_.size(); ++i) {
302     if ((queues_[i].family_index == compute_queue.family_index) &&
303         (queues_[i].queue_index == compute_queue.queue_index)) {
304       std::lock_guard<std::mutex> lock(queue_usage_mutex_);
305       queue_usage_[i] -= 1;
306       break;
307     }
308   }
309 }
310 
submit_cmd(const Adapter::Queue & device_queue,VkCommandBuffer cmd,VkFence fence)311 void Adapter::submit_cmd(
312     const Adapter::Queue& device_queue,
313     VkCommandBuffer cmd,
314     VkFence fence) {
315   const VkSubmitInfo submit_info{
316       VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
317       nullptr, // pNext
318       0u, // waitSemaphoreCount
319       nullptr, // pWaitSemaphores
320       nullptr, // pWaitDstStageMask
321       1u, // commandBufferCount
322       &cmd, // pCommandBuffers
323       0u, // signalSemaphoreCount
324       nullptr, // pSignalSemaphores
325   };
326 
327   std::lock_guard<std::mutex> queue_lock(
328       queue_mutexes_[device_queue.queue_index % NUM_QUEUE_MUTEXES]);
329 
330   VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence));
331 }
332 
submit_cmds(const Adapter::Queue & device_queue,const std::vector<VkCommandBuffer> & cmds,VkFence fence)333 void Adapter::submit_cmds(
334     const Adapter::Queue& device_queue,
335     const std::vector<VkCommandBuffer>& cmds,
336     VkFence fence) {
337   const VkSubmitInfo submit_info{
338       VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
339       nullptr, // pNext
340       0u, // waitSemaphoreCount
341       nullptr, // pWaitSemaphores
342       nullptr, // pWaitDstStageMask
343       utils::safe_downcast<uint32_t>(cmds.size()), // commandBufferCount
344       cmds.data(), // pCommandBuffers
345       0u, // signalSemaphoreCount
346       nullptr, // pSignalSemaphores
347   };
348 
349   VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence));
350 }
351 
stringize() const352 std::string Adapter::stringize() const {
353   std::stringstream ss;
354 
355   VkPhysicalDeviceProperties properties = physical_device_.properties;
356   uint32_t v_major = VK_VERSION_MAJOR(properties.apiVersion);
357   uint32_t v_minor = VK_VERSION_MINOR(properties.apiVersion);
358   std::string device_type = get_device_type_str(properties.deviceType);
359   VkPhysicalDeviceLimits limits = properties.limits;
360 
361   ss << "{" << std::endl;
362   ss << "  Physical Device Info {" << std::endl;
363   ss << "    apiVersion:    " << v_major << "." << v_minor << std::endl;
364   ss << "    driverversion: " << properties.driverVersion << std::endl;
365   ss << "    deviceType:    " << device_type << std::endl;
366   ss << "    deviceName:    " << properties.deviceName << std::endl;
367 
368 #define PRINT_LIMIT_PROP(name)                                         \
369   ss << "      " << std::left << std::setw(36) << #name << limits.name \
370      << std::endl;
371 
372 #define PRINT_LIMIT_PROP_VEC3(name)                                       \
373   ss << "      " << std::left << std::setw(36) << #name << limits.name[0] \
374      << "," << limits.name[1] << "," << limits.name[2] << std::endl;
375 
376   ss << "    Physical Device Limits {" << std::endl;
377   PRINT_LIMIT_PROP(maxImageDimension1D);
378   PRINT_LIMIT_PROP(maxImageDimension2D);
379   PRINT_LIMIT_PROP(maxImageDimension3D);
380   PRINT_LIMIT_PROP(maxTexelBufferElements);
381   PRINT_LIMIT_PROP(maxPushConstantsSize);
382   PRINT_LIMIT_PROP(maxMemoryAllocationCount);
383   PRINT_LIMIT_PROP(maxSamplerAllocationCount);
384   PRINT_LIMIT_PROP(maxComputeSharedMemorySize);
385   PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupCount);
386   PRINT_LIMIT_PROP(maxComputeWorkGroupInvocations);
387   PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupSize);
388   ss << "    }" << std::endl;
389   ss << "  }" << std::endl;
390   ;
391 
392   const VkPhysicalDeviceMemoryProperties& mem_props =
393       physical_device_.memory_properties;
394 
395   ss << "  Memory Info {" << std::endl;
396   ss << "    Memory Types [" << std::endl;
397   for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
398     ss << "      "
399        << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] "
400        << get_memory_properties_str(mem_props.memoryTypes[i].propertyFlags)
401        << std::endl;
402   }
403   ss << "    ]" << std::endl;
404   ss << "    Memory Heaps [" << std::endl;
405   for (size_t i = 0; i < mem_props.memoryHeapCount; ++i) {
406     ss << "      " << mem_props.memoryHeaps[i].size << std::endl;
407   }
408   ss << "    ]" << std::endl;
409   ss << "  }" << std::endl;
410 
411   ss << "  Queue Families {" << std::endl;
412   for (const VkQueueFamilyProperties& queue_family_props :
413        physical_device_.queue_families) {
414     ss << "    (" << queue_family_props.queueCount << " Queues) "
415        << get_queue_family_properties_str(queue_family_props.queueFlags)
416        << std::endl;
417   }
418   ss << "  }" << std::endl;
419   ss << "  VkDevice: " << device_.handle_ << std::endl;
420   ss << "  Compute Queues [" << std::endl;
421   for (const Adapter::Queue& compute_queue : queues_) {
422     ss << "    Family " << compute_queue.family_index << ", Queue "
423        << compute_queue.queue_index << ": " << compute_queue.handle
424        << std::endl;
425     ;
426   }
427   ss << "  ]" << std::endl;
428   ss << "}";
429 
430   return ss.str();
431 }
432 
operator <<(std::ostream & os,const Adapter & adapter)433 std::ostream& operator<<(std::ostream& os, const Adapter& adapter) {
434   os << adapter.stringize() << std::endl;
435   return os;
436 }
437 
438 } // namespace api
439 } // namespace vulkan
440 } // namespace native
441 } // namespace at
442