xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/orchestration/vulkan.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/profiler/orchestration/vulkan.h>
2 
3 #include <utility>
4 
5 namespace torch {
6 namespace profiler {
7 namespace impl {
8 namespace vulkan {
9 namespace {
10 
// File-local slot for the shader name/duration callback. It is assigned by
// registerGetShaderNameAndDurationNs, reset to nullptr by
// deregisterGetShaderNameAndDurationNs, and null-checked then invoked by
// getShaderNameAndDurationNs. No locking is applied to it in this file.
11 GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns_fn;
12 
13 } // namespace
14 
registerGetShaderNameAndDurationNs(GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns)15 void registerGetShaderNameAndDurationNs(
16     GetShaderNameAndDurationNsFn get_shader_name_and_duration_ns) {
17   get_shader_name_and_duration_ns_fn =
18       std::move(get_shader_name_and_duration_ns);
19 }
20 
deregisterGetShaderNameAndDurationNs()21 void deregisterGetShaderNameAndDurationNs() {
22   get_shader_name_and_duration_ns_fn = nullptr;
23 }
24 
getShaderNameAndDurationNs(const vulkan_id_t & vulkan_id)25 std::tuple<std::string, uint64_t> getShaderNameAndDurationNs(
26     const vulkan_id_t& vulkan_id) {
27   /*
28     We don't need to worry about a race condition with
29     deregisterGetShaderNameAndDurationNs here currently because
30     deregisterGetShaderNameAndDurationNs is only called within the destructor
31     of QueryPool, which would only be called after we're done calling
32     getShaderNameAndDurationNs
33   */
34   TORCH_CHECK(
35       get_shader_name_and_duration_ns_fn != nullptr,
36       "Attempting to get shader duration in ",
37       "torch::profiler::impl::vulkan::getShaderNameAndDurationNs, but "
38       "get_shader_duration_fn is unregistered. Use "
39       "torch::profiler::impl::vulkan::registerGetShaderNameAndDurationNs to register "
40       "it first");
41   return get_shader_name_and_duration_ns_fn(vulkan_id.value_of());
42 }
43 
44 } // namespace vulkan
45 } // namespace impl
46 } // namespace profiler
47 } // namespace torch
48