#include <ATen/native/vulkan/api/Adapter.h>

#include <algorithm>
#include <bitset>
#include <cstring>
#include <iomanip>
#include <sstream>
#include <utility>
8
9 namespace at {
10 namespace native {
11 namespace vulkan {
12 namespace api {
13
PhysicalDevice(VkPhysicalDevice physical_device_handle)14 PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
15 : handle(physical_device_handle),
16 properties{},
17 memory_properties{},
18 queue_families{},
19 num_compute_queues(0),
20 has_unified_memory(false),
21 has_timestamps(properties.limits.timestampComputeAndGraphics),
22 timestamp_period(properties.limits.timestampPeriod) {
23 // Extract physical device properties
24 vkGetPhysicalDeviceProperties(handle, &properties);
25 vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties);
26
27 // Check if there are any memory types have both the HOST_VISIBLE and the
28 // DEVICE_LOCAL property flags
29 const VkMemoryPropertyFlags unified_memory_flags =
30 VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
31 for (size_t i = 0; i < memory_properties.memoryTypeCount; ++i) {
32 if (memory_properties.memoryTypes[i].propertyFlags | unified_memory_flags) {
33 has_unified_memory = true;
34 break;
35 }
36 }
37
38 uint32_t queue_family_count = 0;
39 vkGetPhysicalDeviceQueueFamilyProperties(
40 handle, &queue_family_count, nullptr);
41
42 queue_families.resize(queue_family_count);
43 vkGetPhysicalDeviceQueueFamilyProperties(
44 handle, &queue_family_count, queue_families.data());
45
46 // Find the total number of compute queues
47 for (const VkQueueFamilyProperties& p : queue_families) {
48 // Check if this family has compute capability
49 if (p.queueFlags & VK_QUEUE_COMPUTE_BIT) {
50 num_compute_queues += p.queueCount;
51 }
52 }
53 }
54
55 namespace {
56
find_requested_device_extensions(VkPhysicalDevice physical_device,std::vector<const char * > & enabled_extensions,const std::vector<const char * > & requested_extensions)57 void find_requested_device_extensions(
58 VkPhysicalDevice physical_device,
59 std::vector<const char*>& enabled_extensions,
60 const std::vector<const char*>& requested_extensions) {
61 uint32_t device_extension_properties_count = 0;
62 VK_CHECK(vkEnumerateDeviceExtensionProperties(
63 physical_device, nullptr, &device_extension_properties_count, nullptr));
64 std::vector<VkExtensionProperties> device_extension_properties(
65 device_extension_properties_count);
66 VK_CHECK(vkEnumerateDeviceExtensionProperties(
67 physical_device,
68 nullptr,
69 &device_extension_properties_count,
70 device_extension_properties.data()));
71
72 std::vector<const char*> enabled_device_extensions;
73
74 for (const auto& requested_extension : requested_extensions) {
75 for (const auto& extension : device_extension_properties) {
76 if (strcmp(requested_extension, extension.extensionName) == 0) {
77 enabled_extensions.push_back(requested_extension);
78 break;
79 }
80 }
81 }
82 }
83
create_logical_device(const PhysicalDevice & physical_device,const uint32_t num_queues_to_create,std::vector<Adapter::Queue> & queues,std::vector<uint32_t> & queue_usage)84 VkDevice create_logical_device(
85 const PhysicalDevice& physical_device,
86 const uint32_t num_queues_to_create,
87 std::vector<Adapter::Queue>& queues,
88 std::vector<uint32_t>& queue_usage) {
89 // Find compute queues up to the requested number of queues
90
91 std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
92 queue_create_infos.reserve(num_queues_to_create);
93
94 std::vector<std::pair<uint32_t, uint32_t>> queues_to_get;
95 queues_to_get.reserve(num_queues_to_create);
96
97 uint32_t remaining_queues = num_queues_to_create;
98 for (uint32_t family_i = 0; family_i < physical_device.queue_families.size();
99 ++family_i) {
100 const VkQueueFamilyProperties& queue_properties =
101 physical_device.queue_families.at(family_i);
102 // Check if this family has compute capability
103 if (queue_properties.queueFlags & VK_QUEUE_COMPUTE_BIT) {
104 const uint32_t queues_to_init =
105 std::min(remaining_queues, queue_properties.queueCount);
106
107 const std::vector<float> queue_priorities(queues_to_init, 1.0f);
108 queue_create_infos.push_back({
109 VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO, // sType
110 nullptr, // pNext
111 0u, // flags
112 family_i, // queueFamilyIndex
113 queues_to_init, // queueCount
114 queue_priorities.data(), // pQueuePriorities
115 });
116
117 for (size_t queue_i = 0; queue_i < queues_to_init; ++queue_i) {
118 // Use this to get the queue handle once device is created
119 queues_to_get.emplace_back(family_i, queue_i);
120 }
121 remaining_queues -= queues_to_init;
122 }
123 if (remaining_queues == 0) {
124 break;
125 }
126 }
127
128 queues.reserve(queues_to_get.size());
129 queue_usage.reserve(queues_to_get.size());
130
131 // Create the VkDevice
132
133 std::vector<const char*> requested_device_extensions{
134 #ifdef VK_KHR_portability_subset
135 VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
136 #endif /* VK_KHR_portability_subset */
137 };
138
139 std::vector<const char*> enabled_device_extensions;
140 find_requested_device_extensions(
141 physical_device.handle,
142 enabled_device_extensions,
143 requested_device_extensions);
144
145 const VkDeviceCreateInfo device_create_info{
146 VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
147 nullptr, // pNext
148 0u, // flags
149 static_cast<uint32_t>(queue_create_infos.size()), // queueCreateInfoCount
150 queue_create_infos.data(), // pQueueCreateInfos
151 0u, // enabledLayerCount
152 nullptr, // ppEnabledLayerNames
153 static_cast<uint32_t>(
154 enabled_device_extensions.size()), // enabledExtensionCount
155 enabled_device_extensions.data(), // ppEnabledExtensionNames
156 nullptr, // pEnabledFeatures
157 };
158
159 VkDevice handle = nullptr;
160 VK_CHECK(vkCreateDevice(
161 physical_device.handle, &device_create_info, nullptr, &handle));
162
163 #ifdef USE_VULKAN_VOLK
164 volkLoadDevice(handle);
165 #endif /* USE_VULKAN_VOLK */
166
167 // Obtain handles for the created queues and initialize queue usage heuristic
168
169 for (const std::pair<uint32_t, uint32_t>& queue_idx : queues_to_get) {
170 VkQueue queue_handle = VK_NULL_HANDLE;
171 VkQueueFlags flags =
172 physical_device.queue_families.at(queue_idx.first).queueFlags;
173 vkGetDeviceQueue(handle, queue_idx.first, queue_idx.second, &queue_handle);
174 queues.push_back({queue_idx.first, queue_idx.second, flags, queue_handle});
175 // Initial usage value
176 queue_usage.push_back(0);
177 }
178
179 return handle;
180 }
181
182 // Print utils
183
get_device_type_str(const VkPhysicalDeviceType type)184 std::string get_device_type_str(const VkPhysicalDeviceType type) {
185 switch (type) {
186 case VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU:
187 return "INTEGRATED_GPU";
188 case VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU:
189 return "DISCRETE_GPU";
190 case VK_PHYSICAL_DEVICE_TYPE_VIRTUAL_GPU:
191 return "VIRTUAL_GPU";
192 case VK_PHYSICAL_DEVICE_TYPE_CPU:
193 return "CPU";
194 default:
195 return "UNKNOWN";
196 }
197 }
198
// Render a VkMemoryPropertyFlags bitmask as a pipe-delimited list of flag
// names, e.g. "| DEVICE_LOCAL | HOST_VISIBLE |". Only bits 0-4 (DEVICE_LOCAL,
// HOST_VISIBLE, HOST_COHERENT, HOST_CACHED, LAZILY_ALLOCATED) are decoded;
// other bits are ignored. A mask with none of them set renders as "|".
std::string get_memory_properties_str(const VkMemoryPropertyFlags flags) {
  std::bitset<10> values(flags);
  // Emit the leading separator explicitly. Constructing the stream as
  // `std::stringstream ss("|")` leaves the put position at 0, so the first
  // insertion would overwrite the "|" instead of appending after it.
  std::stringstream ss;
  ss << "|";
  if (values[0]) {
    ss << " DEVICE_LOCAL |";
  }
  if (values[1]) {
    ss << " HOST_VISIBLE |";
  }
  if (values[2]) {
    ss << " HOST_COHERENT |";
  }
  if (values[3]) {
    ss << " HOST_CACHED |";
  }
  if (values[4]) {
    ss << " LAZILY_ALLOCATED |";
  }

  return ss.str();
}
220
// Render a VkQueueFlags bitmask as a pipe-delimited list of capability
// names, e.g. "| GRAPHICS | COMPUTE |". Only bits 0-2 (GRAPHICS, COMPUTE,
// TRANSFER) are decoded; other bits are ignored. A mask with none of them
// set renders as "|".
std::string get_queue_family_properties_str(const VkQueueFlags flags) {
  std::bitset<10> values(flags);
  // Emit the leading separator explicitly. Constructing the stream as
  // `std::stringstream ss("|")` leaves the put position at 0, so the first
  // insertion would overwrite the "|" instead of appending after it.
  std::stringstream ss;
  ss << "|";
  if (values[0]) {
    ss << " GRAPHICS |";
  }
  if (values[1]) {
    ss << " COMPUTE |";
  }
  if (values[2]) {
    ss << " TRANSFER |";
  }

  return ss.str();
}
236
237 } // namespace
238
239 //
240 // DeviceHandle
241 //
242
DeviceHandle(VkDevice device)243 DeviceHandle::DeviceHandle(VkDevice device) : handle_(device) {}
244
DeviceHandle(DeviceHandle && other)245 DeviceHandle::DeviceHandle(DeviceHandle&& other) noexcept
246 : handle_(other.handle_) {
247 other.handle_ = VK_NULL_HANDLE;
248 }
249
~DeviceHandle()250 DeviceHandle::~DeviceHandle() {
251 if (VK_NULL_HANDLE == handle_) {
252 return;
253 }
254 vkDestroyDevice(handle_, nullptr);
255 }
256
257 //
258 // Adapter
259 //
260
// Construct an Adapter for `physical_device` on `instance`.
//
// create_logical_device() creates the VkDevice and, as a side effect, fills
// queues_ (one entry per created compute queue, at most `num_queues`) and
// queue_usage_ (a zero usage counter per queue). The shader-layout, shader,
// pipeline-layout, compute-pipeline and sampler caches and the VMA allocator
// are then constructed against the new device handle.
//
// NOTE(review): members are initialized in their declaration order in
// Adapter.h, not in the order written here. device_ reads physical_device_,
// queues_ and queue_usage_, and the caches / vma_ read device_ and instance_,
// so those must be declared earlier in the header. Presumably they are;
// confirm before reordering any member.
Adapter::Adapter(
    VkInstance instance,
    PhysicalDevice physical_device,
    const uint32_t num_queues)
    : queue_usage_mutex_{},
      // Taken by value and moved in, so callers may pass a temporary.
      physical_device_(std::move(physical_device)),
      // Populated by create_logical_device() below.
      queues_{},
      queue_usage_{},
      queue_mutexes_{},
      instance_(instance),
      device_(create_logical_device(
          physical_device_,
          num_queues,
          queues_,
          queue_usage_)),
      shader_layout_cache_(device_.handle_),
      shader_cache_(device_.handle_),
      pipeline_layout_cache_(device_.handle_),
      compute_pipeline_cache_(device_.handle_),
      sampler_cache_(device_.handle_),
      vma_(instance_, physical_device_.handle, device_.handle_) {}
282
request_queue()283 Adapter::Queue Adapter::request_queue() {
284 // Lock the mutex as multiple threads can request a queue at the same time
285 std::lock_guard<std::mutex> lock(queue_usage_mutex_);
286
287 uint32_t min_usage = UINT32_MAX;
288 uint32_t min_used_i = 0;
289 for (size_t i = 0; i < queues_.size(); ++i) {
290 if (queue_usage_[i] < min_usage) {
291 min_used_i = i;
292 min_usage = queue_usage_[i];
293 }
294 }
295 queue_usage_[min_used_i] += 1;
296
297 return queues_[min_used_i];
298 }
299
return_queue(Adapter::Queue & compute_queue)300 void Adapter::return_queue(Adapter::Queue& compute_queue) {
301 for (size_t i = 0; i < queues_.size(); ++i) {
302 if ((queues_[i].family_index == compute_queue.family_index) &&
303 (queues_[i].queue_index == compute_queue.queue_index)) {
304 std::lock_guard<std::mutex> lock(queue_usage_mutex_);
305 queue_usage_[i] -= 1;
306 break;
307 }
308 }
309 }
310
submit_cmd(const Adapter::Queue & device_queue,VkCommandBuffer cmd,VkFence fence)311 void Adapter::submit_cmd(
312 const Adapter::Queue& device_queue,
313 VkCommandBuffer cmd,
314 VkFence fence) {
315 const VkSubmitInfo submit_info{
316 VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
317 nullptr, // pNext
318 0u, // waitSemaphoreCount
319 nullptr, // pWaitSemaphores
320 nullptr, // pWaitDstStageMask
321 1u, // commandBufferCount
322 &cmd, // pCommandBuffers
323 0u, // signalSemaphoreCount
324 nullptr, // pSignalSemaphores
325 };
326
327 std::lock_guard<std::mutex> queue_lock(
328 queue_mutexes_[device_queue.queue_index % NUM_QUEUE_MUTEXES]);
329
330 VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence));
331 }
332
submit_cmds(const Adapter::Queue & device_queue,const std::vector<VkCommandBuffer> & cmds,VkFence fence)333 void Adapter::submit_cmds(
334 const Adapter::Queue& device_queue,
335 const std::vector<VkCommandBuffer>& cmds,
336 VkFence fence) {
337 const VkSubmitInfo submit_info{
338 VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
339 nullptr, // pNext
340 0u, // waitSemaphoreCount
341 nullptr, // pWaitSemaphores
342 nullptr, // pWaitDstStageMask
343 utils::safe_downcast<uint32_t>(cmds.size()), // commandBufferCount
344 cmds.data(), // pCommandBuffers
345 0u, // signalSemaphoreCount
346 nullptr, // pSignalSemaphores
347 };
348
349 VK_CHECK(vkQueueSubmit(device_queue.handle, 1u, &submit_info, fence));
350 }
351
stringize() const352 std::string Adapter::stringize() const {
353 std::stringstream ss;
354
355 VkPhysicalDeviceProperties properties = physical_device_.properties;
356 uint32_t v_major = VK_VERSION_MAJOR(properties.apiVersion);
357 uint32_t v_minor = VK_VERSION_MINOR(properties.apiVersion);
358 std::string device_type = get_device_type_str(properties.deviceType);
359 VkPhysicalDeviceLimits limits = properties.limits;
360
361 ss << "{" << std::endl;
362 ss << " Physical Device Info {" << std::endl;
363 ss << " apiVersion: " << v_major << "." << v_minor << std::endl;
364 ss << " driverversion: " << properties.driverVersion << std::endl;
365 ss << " deviceType: " << device_type << std::endl;
366 ss << " deviceName: " << properties.deviceName << std::endl;
367
368 #define PRINT_LIMIT_PROP(name) \
369 ss << " " << std::left << std::setw(36) << #name << limits.name \
370 << std::endl;
371
372 #define PRINT_LIMIT_PROP_VEC3(name) \
373 ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
374 << "," << limits.name[1] << "," << limits.name[2] << std::endl;
375
376 ss << " Physical Device Limits {" << std::endl;
377 PRINT_LIMIT_PROP(maxImageDimension1D);
378 PRINT_LIMIT_PROP(maxImageDimension2D);
379 PRINT_LIMIT_PROP(maxImageDimension3D);
380 PRINT_LIMIT_PROP(maxTexelBufferElements);
381 PRINT_LIMIT_PROP(maxPushConstantsSize);
382 PRINT_LIMIT_PROP(maxMemoryAllocationCount);
383 PRINT_LIMIT_PROP(maxSamplerAllocationCount);
384 PRINT_LIMIT_PROP(maxComputeSharedMemorySize);
385 PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupCount);
386 PRINT_LIMIT_PROP(maxComputeWorkGroupInvocations);
387 PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupSize);
388 ss << " }" << std::endl;
389 ss << " }" << std::endl;
390 ;
391
392 const VkPhysicalDeviceMemoryProperties& mem_props =
393 physical_device_.memory_properties;
394
395 ss << " Memory Info {" << std::endl;
396 ss << " Memory Types [" << std::endl;
397 for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) {
398 ss << " "
399 << " [Heap " << mem_props.memoryTypes[i].heapIndex << "] "
400 << get_memory_properties_str(mem_props.memoryTypes[i].propertyFlags)
401 << std::endl;
402 }
403 ss << " ]" << std::endl;
404 ss << " Memory Heaps [" << std::endl;
405 for (size_t i = 0; i < mem_props.memoryHeapCount; ++i) {
406 ss << " " << mem_props.memoryHeaps[i].size << std::endl;
407 }
408 ss << " ]" << std::endl;
409 ss << " }" << std::endl;
410
411 ss << " Queue Families {" << std::endl;
412 for (const VkQueueFamilyProperties& queue_family_props :
413 physical_device_.queue_families) {
414 ss << " (" << queue_family_props.queueCount << " Queues) "
415 << get_queue_family_properties_str(queue_family_props.queueFlags)
416 << std::endl;
417 }
418 ss << " }" << std::endl;
419 ss << " VkDevice: " << device_.handle_ << std::endl;
420 ss << " Compute Queues [" << std::endl;
421 for (const Adapter::Queue& compute_queue : queues_) {
422 ss << " Family " << compute_queue.family_index << ", Queue "
423 << compute_queue.queue_index << ": " << compute_queue.handle
424 << std::endl;
425 ;
426 }
427 ss << " ]" << std::endl;
428 ss << "}";
429
430 return ss.str();
431 }
432
operator <<(std::ostream & os,const Adapter & adapter)433 std::ostream& operator<<(std::ostream& os, const Adapter& adapter) {
434 os << adapter.stringize() << std::endl;
435 return os;
436 }
437
438 } // namespace api
439 } // namespace vulkan
440 } // namespace native
441 } // namespace at
442