xref: /aosp_15_r20/external/pytorch/c10/cuda/driver_api.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
2 #include <c10/cuda/driver_api.h>
3 #include <c10/util/CallOnce.h>
4 #include <c10/util/Exception.h>
5 #include <dlfcn.h>
6 
7 namespace c10::cuda {
8 
9 namespace {
10 
create_driver_api()11 DriverAPI create_driver_api() {
12   void* handle_0 = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_NOLOAD);
13   TORCH_CHECK(handle_0, "Can't open libcuda.so.1: ", dlerror());
14   void* handle_1 = DriverAPI::get_nvml_handle();
15   DriverAPI r{};
16 
17 #define LOOKUP_LIBCUDA_ENTRY(name)                       \
18   r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
19   TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name, ": ", dlerror())
20   C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
21 #undef LOOKUP_LIBCUDA_ENTRY
22 
23 #define LOOKUP_LIBCUDA_ENTRY(name)                       \
24   r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
25   dlerror();
26   C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
27 #undef LOOKUP_LIBCUDA_ENTRY
28 
29   if (handle_1) {
30 #define LOOKUP_NVML_ENTRY(name)                          \
31   r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \
32   TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name, ": ", dlerror())
33     C10_NVML_DRIVER_API(LOOKUP_NVML_ENTRY)
34 #undef LOOKUP_NVML_ENTRY
35   }
36   return r;
37 }
38 } // namespace
39 
get_nvml_handle()40 void* DriverAPI::get_nvml_handle() {
41   static void* nvml_hanle = dlopen("libnvidia-ml.so.1", RTLD_LAZY);
42   return nvml_hanle;
43 }
44 
get()45 C10_EXPORT DriverAPI* DriverAPI::get() {
46   static DriverAPI singleton = create_driver_api();
47   return &singleton;
48 }
49 
50 } // namespace c10::cuda
51 
52 #endif
53