1 #pragma once 2 3 #include <ATen/cuda/ATenCUDAGeneral.h> 4 #include <cuda.h> 5 #include <nvrtc.h> 6 7 namespace at { namespace cuda { 8 9 10 // NOTE [ USE OF NVRTC AND DRIVER API ] 11 // 12 // ATen does not directly link to either libnvrtc or libcuda because they 13 // require libcuda to be installed, yet we want our GPU build to work on CPU 14 // machines as long as CUDA is not initialized. 15 // 16 // Normal CUDA code in torch uses the cuda runtime libraries which can be 17 // installed even if the driver is not installed, but sometimes we specifically 18 // need to use the driver API (e.g., to load JIT compiled code). 19 // To accomplish this, we lazily link libcaffe2_nvrtc which provides a struct 20 // at::cuda::NVRTC that contains function pointers to all of the apis we need. 21 // 22 // IT IS AN ERROR TO TRY TO CALL ANY nvrtc* or cu* FUNCTION DIRECTLY. 23 // INSTEAD USE, e.g. 24 // detail::getCUDAHooks().nvrtc().cuLoadModule(...) 25 // or 26 // globalContext().getNVRTC().cuLoadModule(...) 27 // 28 // If a function is missing add it to the list in ATen/cuda/nvrtc_stub/ATenNVRTC.h 29 // and edit ATen/cuda/detail/LazyNVRTC.cpp accordingly (e.g., via one of the stub 30 // macros). 31 32 #if !defined(USE_ROCM) 33 34 #define AT_FORALL_NVRTC_BASE(_) \ 35 _(nvrtcVersion) \ 36 _(nvrtcAddNameExpression) \ 37 _(nvrtcCreateProgram) \ 38 _(nvrtcDestroyProgram) \ 39 _(nvrtcGetPTXSize) \ 40 _(nvrtcGetPTX) \ 41 _(nvrtcCompileProgram) \ 42 _(nvrtcGetErrorString) \ 43 _(nvrtcGetProgramLogSize) \ 44 _(nvrtcGetProgramLog) \ 45 _(nvrtcGetLoweredName) \ 46 _(cuModuleLoadData) \ 47 _(cuModuleLoadDataEx) \ 48 _(cuModuleGetFunction) \ 49 _(cuOccupancyMaxActiveBlocksPerMultiprocessor) \ 50 _(cuGetErrorString) \ 51 _(cuLaunchKernel) \ 52 _(cuLaunchCooperativeKernel) \ 53 _(cuCtxGetCurrent) \ 54 _(cuCtxSetCurrent) \ 55 _(cuModuleUnload) \ 56 _(cuDevicePrimaryCtxGetState) \ 57 _(cuDevicePrimaryCtxRetain) \ 58 _(cuLinkCreate) \ 59 _(cuLinkAddData) \ 60 _(cuLinkComplete) \ 61 _(cuFuncSetAttribute) \ 62 _(cuFuncGetAttribute) \ 63 64 #if defined(CUDA_VERSION) && CUDA_VERSION >= 12000 65 #define AT_FORALL_NVRTC_EXTENDED(_) \ 66 AT_FORALL_NVRTC_BASE(_) \ 67 _(cuTensorMapEncodeTiled) 68 #else 69 #define AT_FORALL_NVRTC_EXTENDED(_) \ 70 AT_FORALL_NVRTC_BASE(_) 71 #endif 72 73 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010 74 #define AT_FORALL_NVRTC(_) \ 75 AT_FORALL_NVRTC_EXTENDED(_) \ 76 _(nvrtcGetCUBINSize) \ 77 _(nvrtcGetCUBIN) 78 #else 79 #define AT_FORALL_NVRTC(_) \ 80 AT_FORALL_NVRTC_EXTENDED(_) 81 #endif 82 83 #else 84 85 // NOTE [ ATen NVRTC Stub and HIP ] 86 // 87 // ATen's NVRTC stub library, caffe2_nvrtc, provides dynamic loading of both 88 // NVRTC and driver APIs. While the former is not yet supported for HIP, the 89 // later is supported and needed (e.g., in CUDAHooks::getDeviceWithPrimaryContext() 90 // used by tensor.pin_memory()). 91 // 92 // The macro below strips out certain unsupported operations on HIP from the full 93 // list above. 94 // 95 // HIP doesn't have 96 // cuGetErrorString (maps to non-functional hipGetErrorString___) 97 // 98 // HIP from ROCm 3.5 on renamed hipOccupancyMaxActiveBlocksPerMultiprocessor 99 // to hipModuleOccupancyMaxActiveBlocksPerMultiprocessor. 100 #if TORCH_HIP_VERSION < 305 101 #define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR hipOccupancyMaxActiveBlocksPerMultiprocessor 102 #else 103 #define HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR cuOccupancyMaxActiveBlocksPerMultiprocessor 104 #endif 105 106 #define AT_FORALL_NVRTC(_) \ 107 _(nvrtcVersion) \ 108 _(nvrtcCreateProgram) \ 109 _(nvrtcAddNameExpression) \ 110 _(nvrtcDestroyProgram) \ 111 _(nvrtcGetPTXSize) \ 112 _(nvrtcGetPTX) \ 113 _(cuModuleLoadData) \ 114 _(cuModuleGetFunction) \ 115 _(HIPOCCUPANCYMAXACTIVEBLOCKSPERMULTIPROCESSOR) \ 116 _(nvrtcGetErrorString) \ 117 _(nvrtcGetProgramLogSize) \ 118 _(nvrtcGetProgramLog) \ 119 _(cuLaunchKernel) \ 120 _(nvrtcCompileProgram) \ 121 _(cuCtxGetCurrent) \ 122 _(nvrtcGetLoweredName) \ 123 _(cuModuleUnload) \ 124 _(cuDevicePrimaryCtxGetState) 125 126 #endif 127 128 extern "C" typedef struct NVRTC { 129 #define CREATE_MEMBER(name) decltype(&name) name; 130 AT_FORALL_NVRTC(CREATE_MEMBER) 131 #undef CREATE_MEMBER 132 } NVRTC; 133 134 extern "C" TORCH_CUDA_CPP_API NVRTC* load_nvrtc(); 135 }} // at::cuda 136