#pragma once

#include <cublas_v2.h>
#include <cusparse.h>
#include <c10/macros/Export.h>

#ifdef CUDART_VERSION
#include <cusolver_common.h>
#endif

#if defined(USE_CUDSS)
#include <cudss.h>
#endif

#include <ATen/Context.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>


namespace c10 {

class CuDNNError : public c10::Error {
  using Error::Error;
};

} // namespace c10

#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                                      \
  do {                                                                          \
    auto error_object = EXPR;                                                   \
    if (!error_object.is_good()) {                                              \
      TORCH_CHECK_WITH(CuDNNError, false,                                       \
          "cuDNN Frontend error: ", error_object.get_message());                \
    }                                                                           \
  } while (0)

#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)

// See Note [CHECK macro]
#define AT_CUDNN_CHECK(EXPR, ...)                                               \
  do {                                                                          \
    cudnnStatus_t status = EXPR;                                                \
    if (status != CUDNN_STATUS_SUCCESS) {                                       \
      if (status == CUDNN_STATUS_NOT_SUPPORTED) {                               \
        TORCH_CHECK_WITH(CuDNNError, false,                                     \
            "cuDNN error: ",                                                    \
            cudnnGetErrorString(status),                                        \
            ". This error may appear if you passed in a non-contiguous input.", \
            ##__VA_ARGS__);                                                     \
      } else {                                                                  \
        TORCH_CHECK_WITH(CuDNNError, false,                                     \
            "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);       \
      }                                                                         \
    }                                                                           \
  } while (0)

namespace at::cuda::blas {
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
} // namespace at::cuda::blas

#define TORCH_CUDABLAS_CHECK(EXPR)                                              \
  do {                                                                          \
    cublasStatus_t __err = EXPR;                                                \
    TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS,                                 \
                "CUDA error: ",                                                 \
                at::cuda::blas::_cublasGetErrorEnum(__err),                     \
                " when calling `" #EXPR "`");                                   \
  } while (0)

const char* cusparseGetErrorString(cusparseStatus_t status);

#define TORCH_CUDASPARSE_CHECK(EXPR)                                            \
  do {                                                                          \
    cusparseStatus_t __err = EXPR;                                              \
    TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS,                               \
                "CUDA error: ",                                                 \
                cusparseGetErrorString(__err),                                  \
                " when calling `" #EXPR "`");                                   \
  } while (0)

#if defined(USE_CUDSS)
namespace at::cuda::cudss {
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);
} // namespace at::cuda::cudss

#define TORCH_CUDSS_CHECK(EXPR)                                                 \
  do {                                                                          \
    cudssStatus_t __err = EXPR;                                                 \
    if (__err == CUDSS_STATUS_EXECUTION_FAILED) {                               \
      TORCH_CHECK_LINALG(                                                       \
          false,                                                                \
          "cudss error: ",                                                      \
          at::cuda::cudss::cudssGetErrorMessage(__err),                         \
          ", when calling `" #EXPR "`",                                         \
          ". This error may appear if the input matrix contains NaN. ");        \
    } else {                                                                    \
      TORCH_CHECK(                                                              \
          __err == CUDSS_STATUS_SUCCESS,                                        \
          "cudss error: ",                                                      \
          at::cuda::cudss::cudssGetErrorMessage(__err),                         \
          ", when calling `" #EXPR "`. ");                                      \
    }                                                                           \
  } while (0)
#else
#define TORCH_CUDSS_CHECK(EXPR) EXPR
#endif
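
// Example usage of the library-status check macros above (an illustrative
// sketch only; `handle` and the surrounding code are hypothetical and not part
// of this header). Each macro evaluates EXPR once, compares the returned status
// against the library's success code, and on failure throws an error whose
// message embeds both the decoded status string and the literal expression:
//
//   cublasHandle_t handle = nullptr;
//   TORCH_CUDABLAS_CHECK(cublasCreate(&handle));
//   // ... use the handle ...
//   TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
//
// A failing call produces a message such as
//   "CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasDestroy(handle)`".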

// cusolver-related headers are only supported on CUDA for now.
#ifdef CUDART_VERSION

namespace at::cuda::solver {
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);

constexpr const char* _cusolver_backend_suggestion =
    "If you keep seeing this error, you may use "
    "`torch.backends.cuda.preferred_linalg_library()` to try "
    "linear algebra operators with other supported backends. "
    "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";

} // namespace at::cuda::solver

// When CUDA < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when the input contains NaN.
// When CUDA >= 11.5, cusolver normally finishes execution and sets the info array to indicate a convergence issue.
#define TORCH_CUSOLVER_CHECK(EXPR)                                              \
  do {                                                                          \
    cusolverStatus_t __err = EXPR;                                              \
    if ((CUDA_VERSION < 11500 &&                                                \
         __err == CUSOLVER_STATUS_EXECUTION_FAILED) ||                          \
        (CUDA_VERSION >= 11500 &&                                               \
         __err == CUSOLVER_STATUS_INVALID_VALUE)) {                             \
      TORCH_CHECK_LINALG(                                                       \
          false,                                                                \
          "cusolver error: ",                                                   \
          at::cuda::solver::cusolverGetErrorMessage(__err),                     \
          ", when calling `" #EXPR "`",                                         \
          ". This error may appear if the input matrix contains NaN. ",         \
          at::cuda::solver::_cusolver_backend_suggestion);                      \
    } else {                                                                    \
      TORCH_CHECK(                                                              \
          __err == CUSOLVER_STATUS_SUCCESS,                                     \
          "cusolver error: ",                                                   \
          at::cuda::solver::cusolverGetErrorMessage(__err),                     \
          ", when calling `" #EXPR "`. ",                                       \
          at::cuda::solver::_cusolver_backend_suggestion);                      \
    }                                                                           \
  } while (0)

#else
#define TORCH_CUSOLVER_CHECK(EXPR) EXPR
#endif

#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)

// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its cuGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#if !defined(USE_ROCM)

#define AT_CUDA_DRIVER_CHECK(EXPR)                                              \
  do {                                                                          \
    CUresult __err = EXPR;                                                      \
    if (__err != CUDA_SUCCESS) {                                                \
      const char* err_str;                                                      \
      CUresult get_error_str_err C10_UNUSED =                                   \
          at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str);     \
      if (get_error_str_err != CUDA_SUCCESS) {                                  \
        AT_ERROR("CUDA driver error: unknown error");                           \
      } else {                                                                  \
        AT_ERROR("CUDA driver error: ", err_str);                               \
      }                                                                         \
    }                                                                           \
  } while (0)

#else

#define AT_CUDA_DRIVER_CHECK(EXPR)                                              \
  do {                                                                          \
    CUresult __err = EXPR;                                                      \
    if (__err != CUDA_SUCCESS) {                                                \
      AT_ERROR("CUDA driver error: ", static_cast<int>(__err));                 \
    }                                                                           \
  } while (0)

#endif

// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#define AT_CUDA_NVRTC_CHECK(EXPR)                                               \
  do {                                                                          \
    nvrtcResult __err = EXPR;                                                   \
    if (__err != NVRTC_SUCCESS) {                                               \
      if (static_cast<int>(__err) != 7) {                                       \
        AT_ERROR("CUDA NVRTC error: ",                                          \
                 at::globalContext().getNVRTC().nvrtcGetErrorString(__err));    \
      } else {                                                                  \
        AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");    \
      }                                                                         \
    }                                                                           \
  } while (0)
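
// Example usage of the runtime/driver/NVRTC check macros above (an illustrative
// sketch only; `buf`, `nbytes`, `module`, and the kernel name below are
// hypothetical and not part of this header):
//
//   // Runtime API: AT_CUDA_CHECK simply forwards to C10_CUDA_CHECK.
//   AT_CUDA_CHECK(cudaMemset(buf, 0, nbytes));
//
//   // Driver API, called through the dynamically loaded stub.
//   CUfunction fn = nullptr;
//   AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleGetFunction(
//       &fn, module, "my_kernel"));
//
//   // NVRTC: error code 7 is reported with its proper name, as noted above.
//   nvrtcProgram prog = nullptr;
//   AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram(
//       &prog, "extern \"C\" __global__ void my_kernel() {}",
//       "my_kernel.cu", 0, nullptr, nullptr));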