1 #include <torch/csrc/profiler/orchestration/vulkan.h> 2 3 #include <utility> 4 5 namespace torch { 6 namespace profiler { 7 namespace impl { 8 namespace vulkan { 9 namespace { 10 11 GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns_fn; 12 13 } // namespace 14 registerGetShaderNameAndDurationNs(GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns)15void registerGetShaderNameAndDurationNs( 16 GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns) { 17 get_shader_name_and_duration_ns_fn = 18 std::move(get_shader_name_and_duration_ns); 19 } 20 deregisterGetShaderNameAndDurationNs()21void deregisterGetShaderNameAndDurationNs() { 22 get_shader_name_and_duration_ns_fn = nullptr; 23 } 24 getShaderNameAndDurationNs(const vulkan_id_t & vulkan_id)25std::tuple<std::string, uint64_t> getShaderNameAndDurationNs( 26 const vulkan_id_t& vulkan_id) { 27 /* 28 We don't need to worry about a race condition with 29 deregisterGetShaderNameAndDurationNs here currently because 30 deregisterGetShaderNameAndDurationNs is only called within the destructor 31 of QueryPool, which would only be called after we're done calling 32 getShaderNameAndDurationNs 33 */ 34 TORCH_CHECK( 35 get_shader_name_and_duration_ns_fn != nullptr, 36 "Attempting to get shader duration in ", 37 "torch::profiler::impl::vulkan::getShaderNameAndDurationNs, but " 38 "get_shader_duration_fn is unregistered. Use " 39 "torch::profiler::impl::vulkan::registerGetShaderNameAndDurationNs to register " 40 "it first"); 41 return get_shader_name_and_duration_ns_fn(vulkan_id.value_of()); 42 } 43 44 } // namespace vulkan 45 } // namespace impl 46 } // namespace profiler 47 } // namespace torch 48