1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName 12 13 #include <executorch/backends/vulkan/runtime/vk_api/vk_api.h> 14 15 #include <executorch/backends/vulkan/runtime/vk_api/Adapter.h> 16 17 #include <functional> 18 #include <memory> 19 20 namespace vkcompute { 21 namespace vkapi { 22 23 // 24 // A Vulkan Runtime initializes a Vulkan instance and decouples the concept of 25 // Vulkan instance initialization from initialization of, and subsequent 26 // interactions with, Vulkan [physical and logical] devices as a precursor to 27 // multi-GPU support. The Vulkan Runtime can be queried for available Adapters 28 // (i.e. physical devices) in the system which in turn can be used for creation 29 // of a Vulkan Context (i.e. logical devices). All Vulkan tensors in PyTorch 30 // are associated with a Context to make tensor <-> device affinity explicit. 31 // 32 33 enum AdapterSelector { 34 First, 35 }; 36 37 struct RuntimeConfig final { 38 bool enable_validation_messages; 39 bool init_default_device; 40 AdapterSelector default_selector; 41 uint32_t num_requested_queues; 42 std::string cache_data_path; 43 }; 44 45 class Runtime final { 46 public: 47 explicit Runtime(const RuntimeConfig); 48 49 // Do not allow copying. There should be only one global instance of this 50 // class. 51 Runtime(const Runtime&) = delete; 52 Runtime& operator=(const Runtime&) = delete; 53 54 Runtime(Runtime&&) = delete; 55 Runtime& operator=(Runtime&&) = delete; 56 57 ~Runtime(); 58 59 using DeviceMapping = std::pair<PhysicalDevice, int32_t>; 60 using AdapterPtr = std::unique_ptr<Adapter>; 61 62 private: 63 RuntimeConfig config_; 64 65 VkInstance instance_; 66 67 std::vector<DeviceMapping> device_mappings_; 68 std::vector<AdapterPtr> adapters_; 69 uint32_t default_adapter_i_; 70 71 VkDebugReportCallbackEXT debug_report_callback_; 72 73 public: instance()74 inline VkInstance instance() const { 75 return instance_; 76 } 77 get_adapter_p()78 inline Adapter* get_adapter_p() { 79 VK_CHECK_COND( 80 default_adapter_i_ >= 0 && default_adapter_i_ < adapters_.size(), 81 "Pytorch Vulkan Runtime: Default device adapter is not set correctly!"); 82 return adapters_[default_adapter_i_].get(); 83 } 84 get_adapter_p(uint32_t i)85 inline Adapter* get_adapter_p(uint32_t i) { 86 VK_CHECK_COND( 87 i >= 0 && i < adapters_.size(), 88 "Pytorch Vulkan Runtime: Adapter at index ", 89 i, 90 " is not available!"); 91 return adapters_[i].get(); 92 } 93 default_adapter_i()94 inline uint32_t default_adapter_i() const { 95 return default_adapter_i_; 96 } 97 98 using Selector = 99 std::function<uint32_t(const std::vector<Runtime::DeviceMapping>&)>; 100 uint32_t create_adapter(const Selector&); 101 }; 102 103 // The global runtime is retrieved using this function, where it is declared as 104 // a static local variable. 105 Runtime* runtime(); 106 107 } // namespace vkapi 108 } // namespace vkcompute 109