#pragma once

#include <cublas_v2.h>
#include <cusparse.h>
#include <c10/macros/Export.h>

#ifdef CUDART_VERSION
#include <cusolver_common.h>
#endif

#if defined(USE_CUDSS)
#include <cudss.h>
#endif

#include <ATen/Context.h>
#include <c10/util/Exception.h>
#include <c10/cuda/CUDAException.h>


namespace c10 {

class CuDNNError : public c10::Error {
  using Error::Error;
};

}  // namespace c10

#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...)                                                      \
  do {                                                                                          \
    auto error_object = EXPR;                                                                   \
    if (!error_object.is_good()) {                                                              \
      TORCH_CHECK_WITH(CuDNNError, false,                                                       \
            "cuDNN Frontend error: ", error_object.get_message());                              \
    }                                                                                           \
  } while (0)
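
// Usage sketch for AT_CUDNN_FRONTEND_CHECK: any cudnn_frontend call that
// returns a status object exposing is_good()/get_message() can be wrapped.
// The graph object and validate() call below are illustrative only, not part
// of this header:
//
//   auto graph = std::make_shared<cudnn_frontend::graph::Graph>();
//   AT_CUDNN_FRONTEND_CHECK(graph->validate());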

#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)

// See Note [CHECK macro]
#define AT_CUDNN_CHECK(EXPR, ...)                                                               \
  do {                                                                                          \
    cudnnStatus_t status = EXPR;                                                                \
    if (status != CUDNN_STATUS_SUCCESS) {                                                       \
      if (status == CUDNN_STATUS_NOT_SUPPORTED) {                                               \
        TORCH_CHECK_WITH(CuDNNError, false,                                                     \
            "cuDNN error: ",                                                                    \
            cudnnGetErrorString(status),                                                        \
            ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
      } else {                                                                                  \
        TORCH_CHECK_WITH(CuDNNError, false,                                                     \
            "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__);                       \
      }                                                                                         \
    }                                                                                           \
  } while (0)
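
// Usage sketch for AT_CUDNN_CHECK / AT_CUDNN_CHECK_WITH_SHAPES. The descriptor
// below is illustrative only; the _WITH_SHAPES variant simply appends extra
// diagnostics (e.g. tensor shapes) to the error message after a newline:
//
//   cudnnTensorDescriptor_t desc;
//   AT_CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc));
//   AT_CUDNN_CHECK_WITH_SHAPES(cudnnDestroyTensorDescriptor(desc), "desc was 4-D NCHW");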

namespace at::cuda::blas {
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
} // namespace at::cuda::blas

#define TORCH_CUDABLAS_CHECK(EXPR)                              \
  do {                                                          \
    cublasStatus_t __err = EXPR;                                \
    TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS,                 \
                "CUDA error: ",                                 \
                at::cuda::blas::_cublasGetErrorEnum(__err),     \
                " when calling `" #EXPR "`");                   \
  } while (0)
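
// Usage sketch for TORCH_CUDABLAS_CHECK. Creating a handle is shown only as an
// illustration; callers inside ATen typically reuse the handle returned by
// at::cuda::getCurrentCUDABlasHandle() instead:
//
//   cublasHandle_t handle = nullptr;
//   TORCH_CUDABLAS_CHECK(cublasCreate(&handle));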

const char *cusparseGetErrorString(cusparseStatus_t status);

#define TORCH_CUDASPARSE_CHECK(EXPR)                            \
  do {                                                          \
    cusparseStatus_t __err = EXPR;                              \
    TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS,               \
                "CUDA error: ",                                 \
                cusparseGetErrorString(__err),                  \
                " when calling `" #EXPR "`");                   \
  } while (0)
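
// Usage sketch for TORCH_CUDASPARSE_CHECK (illustrative only; callers inside
// ATen typically reuse the handle from at::cuda::getCurrentCUDASparseHandle()):
//
//   cusparseHandle_t handle = nullptr;
//   TORCH_CUDASPARSE_CHECK(cusparseCreate(&handle));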

#if defined(USE_CUDSS)
namespace at::cuda::cudss {
C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);
} // namespace at::cuda::cudss

#define TORCH_CUDSS_CHECK(EXPR)                                         \
  do {                                                                  \
    cudssStatus_t __err = EXPR;                                         \
    if (__err == CUDSS_STATUS_EXECUTION_FAILED) {                       \
      TORCH_CHECK_LINALG(                                               \
          false,                                                        \
          "cudss error: ",                                              \
          at::cuda::cudss::cudssGetErrorMessage(__err),                 \
          ", when calling `" #EXPR "`",                                 \
          ". This error may appear if the input matrix contains NaN. ");\
    } else {                                                            \
      TORCH_CHECK(                                                      \
          __err == CUDSS_STATUS_SUCCESS,                                \
          "cudss error: ",                                              \
          at::cuda::cudss::cudssGetErrorMessage(__err),                 \
          ", when calling `" #EXPR "`. ");                              \
    }                                                                   \
  } while (0)
#else
#define TORCH_CUDSS_CHECK(EXPR) EXPR
#endif
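
// Usage sketch for TORCH_CUDSS_CHECK when cuDSS is enabled (USE_CUDSS); the
// handle create/destroy calls below are illustrative only. Note that without
// USE_CUDSS the macro is a pass-through, so EXPR is still evaluated:
//
//   cudssHandle_t handle = nullptr;
//   TORCH_CUDSS_CHECK(cudssCreate(&handle));
//   TORCH_CUDSS_CHECK(cudssDestroy(handle));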

// cuSOLVER-related headers are only available when building with CUDA.
#ifdef CUDART_VERSION

namespace at::cuda::solver {
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);

constexpr const char* _cusolver_backend_suggestion =
  "If you keep seeing this error, you may use "
  "`torch.backends.cuda.preferred_linalg_library()` to try "
  "linear algebra operators with other supported backends. "
  "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";

} // namespace at::cuda::solver

// With CUDA < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when the input contains NaN.
// With CUDA >= 11.5, cusolver normally finishes execution and sets the info array to indicate a convergence issue.
#define TORCH_CUSOLVER_CHECK(EXPR)                                      \
  do {                                                                  \
    cusolverStatus_t __err = EXPR;                                      \
    if ((CUDA_VERSION < 11500 &&                                        \
         __err == CUSOLVER_STATUS_EXECUTION_FAILED) ||                  \
        (CUDA_VERSION >= 11500 &&                                       \
         __err == CUSOLVER_STATUS_INVALID_VALUE)) {                     \
      TORCH_CHECK_LINALG(                                               \
          false,                                                        \
          "cusolver error: ",                                           \
          at::cuda::solver::cusolverGetErrorMessage(__err),             \
          ", when calling `" #EXPR "`",                                 \
          ". This error may appear if the input matrix contains NaN. ", \
          at::cuda::solver::_cusolver_backend_suggestion);              \
    } else {                                                            \
      TORCH_CHECK(                                                      \
          __err == CUSOLVER_STATUS_SUCCESS,                             \
          "cusolver error: ",                                           \
          at::cuda::solver::cusolverGetErrorMessage(__err),             \
          ", when calling `" #EXPR "`. ",                               \
          at::cuda::solver::_cusolver_backend_suggestion);              \
    }                                                                   \
  } while (0)

#else
#define TORCH_CUSOLVER_CHECK(EXPR) EXPR
#endif
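
// Usage sketch for TORCH_CUSOLVER_CHECK (illustrative only; callers inside
// ATen typically reuse the handle from at::cuda::getCurrentCUDASolverDnHandle()):
//
//   cusolverDnHandle_t handle = nullptr;
//   TORCH_CUSOLVER_CHECK(cusolverDnCreate(&handle));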

#define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)
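
// Usage sketch for AT_CUDA_CHECK, which simply forwards to C10_CUDA_CHECK:
//
//   AT_CUDA_CHECK(cudaDeviceSynchronize());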

// For CUDA Driver API
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its cuGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#if !defined(USE_ROCM)

#define AT_CUDA_DRIVER_CHECK(EXPR)                                                                               \
  do {                                                                                                           \
    CUresult __err = EXPR;                                                                                       \
    if (__err != CUDA_SUCCESS) {                                                                                 \
      const char* err_str;                                                                                       \
      CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str);  \
      if (get_error_str_err != CUDA_SUCCESS) {                                                                   \
        AT_ERROR("CUDA driver error: unknown error");                                                            \
      } else {                                                                                                   \
        AT_ERROR("CUDA driver error: ", err_str);                                                                \
      }                                                                                                          \
    }                                                                                                            \
  } while (0)

#else

#define AT_CUDA_DRIVER_CHECK(EXPR)                                                \
  do {                                                                            \
    CUresult __err = EXPR;                                                        \
    if (__err != CUDA_SUCCESS) {                                                  \
      AT_ERROR("CUDA driver error: ", static_cast<int>(__err));                   \
    }                                                                             \
  } while (0)

#endif
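
// Usage sketch for AT_CUDA_DRIVER_CHECK; the driver call below is illustrative
// only and assumes cuCtxGetCurrent is exposed on the lazily loaded stub
// returned by at::globalContext().getNVRTC() (see the note above):
//
//   CUcontext ctx = nullptr;
//   AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuCtxGetCurrent(&ctx));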

// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
#define AT_CUDA_NVRTC_CHECK(EXPR)                                                                   \
  do {                                                                                              \
    nvrtcResult __err = EXPR;                                                                       \
    if (__err != NVRTC_SUCCESS) {                                                                   \
      if (static_cast<int>(__err) != 7) {                                                           \
        AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err));  \
      } else {                                                                                      \
        AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");                        \
      }                                                                                             \
    }                                                                                               \
  } while (0)
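
// Usage sketch for AT_CUDA_NVRTC_CHECK; the program name and source string are
// placeholders, and the call is routed through the same lazily loaded NVRTC
// stub described above:
//
//   nvrtcProgram prog = nullptr;
//   AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram(
//       &prog, "extern \"C\" __global__ void k() {}", "k.cu", 0, nullptr, nullptr));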
206