#pragma once

#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>
#include <nvrtc.h>

namespace at { namespace cuda {


// NOTE [ USE OF NVRTC AND DRIVER API ]
//
// ATen does not directly link to either libnvrtc or libcuda because doing so
// would require libcuda (the driver library) to be installed, yet we want our
// GPU build to work on CPU-only machines as long as CUDA is never initialized.
//
// Normal CUDA code in torch uses the CUDA runtime libraries, which can be
// installed even if the driver is not, but sometimes we specifically need to
// use the driver API (e.g., to load JIT-compiled code).
// To accomplish this, we lazily link libcaffe2_nvrtc, which provides a struct
// at::cuda::NVRTC that contains function pointers to all of the APIs we need.
//
// IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY.
// INSTEAD USE, e.g.
//   detail::getCUDAHooks().nvrtc().cuModuleLoadData(...)
// or
//   globalContext().getNVRTC().cuModuleLoadData(...)
//
// If a function is missing, add it to the list in ATen/cuda/nvrtc_stub/ATenNVRTC.h
// and edit ATen/cuda/detail/LazyNVRTC.cpp accordingly (e.g., via one of the stub
// macros).
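//
// A rough usage sketch (illustration only, not a verbatim call site; real
// users live in the CUDA JIT/codegen paths, wrap every call in the
// appropriate error-checking macro, and cuda_source / "kernel" below are
// placeholder names):
//   const auto& nvrtc = at::globalContext().getNVRTC();
//   nvrtcProgram prog;
//   nvrtc.nvrtcCreateProgram(&prog, cuda_source, "kernel.cu", 0, nullptr, nullptr);
//   const char* opts[] = {"--use_fast_math"};
//   nvrtc.nvrtcCompileProgram(prog, 1, opts);
//   size_t ptx_size = 0;
//   nvrtc.nvrtcGetPTXSize(prog, &ptx_size);
//   std::vector<char> ptx(ptx_size);
//   nvrtc.nvrtcGetPTX(prog, ptx.data());
//   CUmodule mod;
//   CUfunction fn;
//   nvrtc.cuModuleLoadData(&mod, ptx.data());
//   nvrtc.cuModuleGetFunction(&fn, mod, "kernel");
//   // ... cuLaunchKernel(fn, ...), then cuModuleUnload(mod) when done.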

#if !defined(USE_ROCM)

#define AT_FORALL_NVRTC_BASE(_)                  \
  _(nvrtcVersion)                                \
  _(nvrtcAddNameExpression)                      \
  _(nvrtcCreateProgram)                          \
  _(nvrtcDestroyProgram)                         \
  _(nvrtcGetPTXSize)                             \
  _(nvrtcGetPTX)                                 \
  _(nvrtcCompileProgram)                         \
  _(nvrtcGetErrorString)                         \
  _(nvrtcGetProgramLogSize)                      \
  _(nvrtcGetProgramLog)                          \
  _(nvrtcGetLoweredName)                         \
  _(cuModuleLoadData)                            \
  _(cuModuleLoadDataEx)                          \
  _(cuModuleGetFunction)                         \
  _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \
  _(cuGetErrorString)                            \
  _(cuLaunchKernel)                              \
  _(cuLaunchCooperativeKernel)                   \
  _(cuCtxGetCurrent)                             \
  _(cuCtxSetCurrent)                             \
  _(cuModuleUnload)                              \
  _(cuDevicePrimaryCtxGetState)                  \
  _(cuDevicePrimaryCtxRetain)                    \
  _(cuLinkCreate)                                \
  _(cuLinkAddData)                               \
  _(cuLinkComplete)                              \
  _(cuFuncSetAttribute)                          \
  _(cuFuncGetAttribute)

#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
#define AT_FORALL_NVRTC_EXTENDED(_)              \
  AT_FORALL_NVRTC_BASE(_)                        \
  _(cuTensorMapEncodeTiled)
#else
#define AT_FORALL_NVRTC_EXTENDED(_)              \
  AT_FORALL_NVRTC_BASE(_)
#endif

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
#define AT_FORALL_NVRTC(_)        \
  AT_FORALL_NVRTC_EXTENDED(_)     \
  _(nvrtcGetCUBINSize)            \
  _(nvrtcGetCUBIN)
#else
#define AT_FORALL_NVRTC(_) \
  AT_FORALL_NVRTC_EXTENDED(_)
#endif

#else

// NOTE [ ATen NVRTC Stub and HIP ]
//
// ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both
// NVRTC and driver APIs. While the former is not yet supported for HIP, the
// latter is supported and needed (e.g., in CUDAHooks::getDeviceWithPrimaryContext()
// used by tensor.pin_memory()).
//
// The macro below strips out certain unsupported operations on HIP from the full
// list above.
//
// HIP doesn't have
//   cuGetErrorString  (maps to non-functional hipGetErrorString___)
//
// HIP from ROCm 3.5 onwards renamed hipOccupancyMaxActiveBlocksPerMultiprocessor
// to hipModuleOccupancyMaxActiveBlocksPerMultiprocessor.
#if TORCH_HIP_VERSION < 305
#define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR hipOccupancyMaxActiveBlocksPerMultiprocessor
#else
#define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR cuOccupancyMaxActiveBlocksPerMultiprocessor
#endif
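
// Illustrative note (an assumption about the ROCm build, not verified here):
// when this header is HIPified for ROCm >= 3.5, the cu* spelling above is
// expected to be rewritten to hipModuleOccupancyMaxActiveBlocksPerMultiprocessor,
// so the generated struct member ends up roughly as
//   decltype(&hipModuleOccupancyMaxActiveBlocksPerMultiprocessor)
//       hipModuleOccupancyMaxActiveBlocksPerMultiprocessor;
// matching the renamed API mentioned in the NOTE above.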

#define AT_FORALL_NVRTC(_)                        \
  _(nvrtcVersion)                                 \
  _(nvrtcCreateProgram)                           \
  _(nvrtcAddNameExpression)                       \
  _(nvrtcDestroyProgram)                          \
  _(nvrtcGetPTXSize)                              \
  _(nvrtcGetPTX)                                  \
  _(cuModuleLoadData)                             \
  _(cuModuleGetFunction)                          \
  _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \
  _(nvrtcGetErrorString)                          \
  _(nvrtcGetProgramLogSize)                       \
  _(nvrtcGetProgramLog)                           \
  _(cuLaunchKernel)                               \
  _(nvrtcCompileProgram)                          \
  _(cuCtxGetCurrent)                              \
  _(nvrtcGetLoweredName)                          \
  _(cuModuleUnload)                               \
  _(cuDevicePrimaryCtxGetState)

#endif

extern "C" typedef struct NVRTC {
#define CREATE_MEMBER(name) decltype(&name) name;
  AT_FORALL_NVRTC(CREATE_MEMBER)
#undef CREATE_MEMBER
} NVRTC;
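
// For reference, CREATE_MEMBER turns each entry of the X-macro list above into
// a function-pointer member named after the underlying API. For example,
// _(nvrtcVersion) expands to
//   decltype(&nvrtcVersion) nvrtcVersion;
// i.e. a member of type nvrtcResult (*)(int*, int*) named nvrtcVersion.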

extern "C" TORCH_CUDA_CPP_API NVRTC* load_nvrtc();
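
// load_nvrtc() is the entry point exported by the stub library
// (libcaffe2_nvrtc) and returns a populated NVRTC table. A hedged sketch of
// how a consumer could resolve it at runtime; the actual lazy binding lives in
// ATen/cuda/detail/CUDAHooks.cpp and ATen/cuda/detail/LazyNVRTC.cpp, and the
// dlopen/dlsym calls and library name below are illustrative only:
//   void* handle = dlopen("libcaffe2_nvrtc.so", RTLD_NOW | RTLD_LOCAL);
//   auto load = reinterpret_cast<at::cuda::NVRTC* (*)()>(dlsym(handle, "load_nvrtc"));
//   at::cuda::NVRTC* nvrtc = load();
//   int major = 0, minor = 0;
//   nvrtc->nvrtcVersion(&major, &minor);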
}} // at::cuda