// xref: /aosp_15_r20/external/pytorch/c10/cuda/driver_api.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <cuda.h>
#define NVML_NO_UNVERSIONED_FUNC_DEFS
#include <nvml.h>

// Evaluates EXPR (an expression yielding a CUresult, typically a CUDA driver
// API call) exactly once and reports a failure through AT_ERROR when the
// result is not CUDA_SUCCESS.
//
// The error text is looked up through the cuGetErrorString_ entry of
// c10::cuda::DriverAPI::get(). The reported message is
//   "CUDA driver error: <text>"
// or, when the lookup itself fails or yields no string,
//   "CUDA driver error: unknown error".
//
// Notes:
//  - AT_ERROR must be visible at the expansion site; this header does not
//    include its definition.
//  - EXPR is parenthesized and the macro's locals use a c10_drv_ prefix (no
//    double underscore, which is reserved for the implementation), so the
//    expansion cannot clash with reserved names or re-associate EXPR.
//  - The message pointer starts out as nullptr and is checked before use, so
//    a failed or empty lookup never streams an indeterminate pointer.
#define C10_CUDA_DRIVER_CHECK(EXPR)                                       \
  do {                                                                    \
    const CUresult c10_drv_status_ = (EXPR);                              \
    if (c10_drv_status_ != CUDA_SUCCESS) {                                \
      const char* c10_drv_msg_ = nullptr;                                 \
      const CUresult c10_drv_lookup_ =                                    \
          ::c10::cuda::DriverAPI::get()->cuGetErrorString_(               \
              c10_drv_status_, &c10_drv_msg_);                            \
      if (c10_drv_lookup_ != CUDA_SUCCESS || c10_drv_msg_ == nullptr) {   \
        AT_ERROR("CUDA driver error: unknown error");                     \
      } else {                                                            \
        AT_ERROR("CUDA driver error: ", c10_drv_msg_);                    \
      }                                                                   \
    }                                                                     \
  } while (0)

// X-macro listing the CUDA driver entry points used through
// c10::cuda::DriverAPI.
//
// Invoke it with the name of a one-argument macro; that macro is applied once
// to each function name, in the order written below. In this header the list
// is expanded inside struct DriverAPI to declare one function-pointer member
// per entry (the member is the entry name with a trailing underscore).
//
// The order of entries determines the order of the generated members, so new
// entries should be appended rather than inserted.
#define C10_LIBCUDA_DRIVER_API(_)   \
  _(cuDeviceGetAttribute)           \
  _(cuMemAddressReserve)            \
  _(cuMemRelease)                   \
  _(cuMemMap)                       \
  _(cuMemAddressFree)               \
  _(cuMemSetAccess)                 \
  _(cuMemUnmap)                     \
  _(cuMemCreate)                    \
  _(cuMemGetAllocationGranularity)  \
  _(cuMemExportToShareableHandle)   \
  _(cuMemImportFromShareableHandle) \
  _(cuGetErrorString)

// X-macro listing the multicast driver entry points. They are only listed
// when CUDA_VERSION is defined and is at least 12030 (CUDA 12.3 in cuda.h's
// version encoding).
//
// With an older or undefined CUDA_VERSION the macro is still defined but
// expands to nothing, so code that expands the list needs no version guard of
// its own.
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_12030(_) \
  _(cuMulticastAddDevice)               \
  _(cuMulticastBindMem)                 \
  _(cuMulticastCreate)
#else
#define C10_LIBCUDA_DRIVER_API_12030(_)
#endif

// X-macro listing the NVML entry points used through c10::cuda::DriverAPI.
//
// Invoke it with the name of a one-argument macro; that macro is applied once
// to each function name, in the order written below.
//
// Versioned entry points are spelled with their explicit suffix (e.g. _v2).
// This header defines NVML_NO_UNVERSIONED_FUNC_DEFS before including nvml.h,
// so the names here are the literal symbol names.
#define C10_NVML_DRIVER_API(_)           \
  _(nvmlInit_v2)                         \
  _(nvmlDeviceGetHandleByPciBusId_v2)    \
  _(nvmlDeviceGetNvLinkRemoteDeviceType) \
  _(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
  _(nvmlDeviceGetComputeRunningProcesses)

51*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
52*da0073e9SAndroid Build Coastguard Worker 
53*da0073e9SAndroid Build Coastguard Worker struct DriverAPI {
54*da0073e9SAndroid Build Coastguard Worker #define CREATE_MEMBER(name) decltype(&name) name##_;
55*da0073e9SAndroid Build Coastguard Worker   C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
56*da0073e9SAndroid Build Coastguard Worker   C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
57*da0073e9SAndroid Build Coastguard Worker   C10_NVML_DRIVER_API(CREATE_MEMBER)
58*da0073e9SAndroid Build Coastguard Worker #undef CREATE_MEMBER
59*da0073e9SAndroid Build Coastguard Worker   static DriverAPI* get();
60*da0073e9SAndroid Build Coastguard Worker   static void* get_nvml_handle();
61*da0073e9SAndroid Build Coastguard Worker };
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
64