1 #pragma once
2
3 #include <exception>
4 #include <memory>
5 #include <string>
6 #include <system_error>
7
8 #include <ATen/detail/FunctionTraits.h>
9 #include <c10/util/C++17.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/StringUtil.h>
12 #include <pybind11/pybind11.h>
13 #include <torch/csrc/Export.h>
14 #include <torch/csrc/jit/runtime/jit_exception.h>
15 #include <torch/csrc/utils/cpp_stacktraces.h>
16 #include <torch/csrc/utils/pybind.h>
17
18 #if defined(USE_DISTRIBUTED)
19 #include <torch/csrc/distributed/c10d/exception.h>
20 #endif
21
PyErr_SetString(PyObject * type,const std::string & message)22 inline void PyErr_SetString(PyObject* type, const std::string& message) {
23 PyErr_SetString(type, message.c_str());
24 }
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set Python warnings immediately,
/// as that requires acquiring the GIL (potential deadlock)
/// and it would need to exit cleanly if the warning raised a
/// Python error. To solve this, we buffer the warnings and
/// process them when we go back to Python.
/// This requires the two try/catch blocks below to handle the
/// following cases:
///   - If no Error is raised in the inner try/catch, the
///     buffered warnings are processed as Python warnings.
///     - If none of them raises an error, the function proceeds with its
///       original return value.
///     - If any of them raises an error, the error is set (PyErr_*) and
///       the PyWarningHandler destructor throws a C++ python_error(), which
///       is caught by the outer try/catch; that handler can then change
///       the return value of the function to reflect the error.
///   - If an Error was raised in the inner try/catch, the inner try/catch
///     must set the Python error. The buffered warnings are then
///     processed as C++ warnings, because we cannot predict beforehand
///     whether a Python warning will raise an error, and we
///     cannot handle two errors at the same time.
/// This advanced handler is only used in the current thread.
/// If any other thread is used, warnings are processed as
/// C++ warnings.
// Opens the error-handling scope for code called from Python.
// Expands to an outer `try {`, a torch::PyWarningHandler object named
// `__enforce_warning_buffer` (see NOTE [ Conversion Cpp Python Warning ]),
// and an inner `try {`. Both braces are left open: the body must be followed
// by END_HANDLE_TH_ERRORS, END_HANDLE_TH_ERRORS_RET(retval) or
// END_HANDLE_TH_ERRORS_PYBIND, which close the inner and outer blocks.
// The handler's destructor is declared noexcept(false), so it may itself throw
// when the outer block is left; the outer catch clauses handle that case.
#define HANDLE_TH_ERRORS \
  try { \
    torch::PyWarningHandler __enforce_warning_buffer; \
    try {
// Helper used by CATCH_CORE_ERRORS. Expands to a single catch clause for
// `const c10::ErrorType&` that:
//   1. takes the error text from e.what() when
//      torch::get_cpp_stacktraces_enabled() is true, otherwise from
//      e.what_without_backtrace();
//   2. passes it through torch::processErrorMsg and sets it as a Python
//      exception of type PythonErrorType;
//   3. executes the caller-supplied `retstmnt`.
// Note: the clause introduces the names `e` and `msg`, which are visible to
// `retstmnt`.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt) \
  catch (const c10::ErrorType& e) { \
    auto msg = torch::get_cpp_stacktraces_enabled() \
        ? e.what() \
        : e.what_without_backtrace(); \
    PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
    retstmnt; \
  }
61
// Only catch torch-specific exceptions
//
// Expands to a sequence of catch clauses (no `try`) meant to follow a try
// block. Each clause leaves a Python error set and then executes `retstmnt`:
//   - python_error and py::error_already_set already carry a Python error,
//     which restore() makes the current error again.
//   - The c10 error types are mapped to the Python exception types listed
//     next to them; any other c10::Error becomes a RuntimeError.
//   - torch::PyTorchError supplies its own Python type via python_type().
// A generic std::exception is NOT caught here (see CATCH_ALL_ERRORS).
// C++ selects the first handler whose type matches, so a derived type must be
// listed before its base. The Dist*Error clauses precede DistError, and all of
// them precede c10::Error. This presumably mirrors the c10 class hierarchy;
// check c10/util/Exception.h before reordering.
#define CATCH_CORE_ERRORS(retstmnt) \
  catch (python_error & e) { \
    e.restore(); \
    retstmnt; \
  } \
  catch (py::error_already_set & e) { \
    e.restore(); \
    retstmnt; \
  } \
  _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
  _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
  _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
  _CATCH_GENERIC_ERROR( \
      NotImplementedError, PyExc_NotImplementedError, retstmnt) \
  _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
  _CATCH_GENERIC_ERROR( \
      OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
  _CATCH_GENERIC_ERROR( \
      DistBackendError, THPException_DistBackendError, retstmnt) \
  _CATCH_GENERIC_ERROR( \
      DistNetworkError, THPException_DistNetworkError, retstmnt) \
  _CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \
  _CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt) \
  _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
  catch (torch::PyTorchError & e) { \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(e.python_type(), msg); \
    retstmnt; \
  }
92
// Alias of CATCH_CORE_ERRORS.
#define CATCH_TH_ERRORS(retstmnt) CATCH_CORE_ERRORS(retstmnt)
94
// CATCH_TH_ERRORS followed by a final catch-all for any other std::exception.
// That exception is reported as a Python RuntimeError, with its message passed
// through torch::processErrorMsg, before `retstmnt` runs.
#define CATCH_ALL_ERRORS(retstmnt) \
  CATCH_TH_ERRORS(retstmnt) \
  catch (const std::exception& e) { \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg); \
    retstmnt; \
  }
102
// Closes a HANDLE_TH_ERRORS scope for functions bound through pybind11.
// Any exception leaving the inner block first calls
// __enforce_warning_buffer.set_in_exception(), which tells the warning handler
// an exception is in flight, and is then rethrown. In the outer handlers:
//   - py::error_already_set, py::builtin_exception and
//     torch::jit::JITException are rethrown unchanged (presumably so pybind11
//     or a translator registered elsewhere handles them);
//   - any other std::exception is converted into a pending Python error by
//     torch::translate_exception_to_python and rethrown as
//     py::error_already_set.
// Exceptions not derived from std::exception propagate unchanged.
// The expansion refers to `py::`, so the using scope needs that alias.
#define END_HANDLE_TH_ERRORS_PYBIND \
  } \
  catch (...) { \
    __enforce_warning_buffer.set_in_exception(); \
    throw; \
  } \
  } \
  catch (py::error_already_set & e) { \
    throw; \
  } \
  catch (py::builtin_exception & e) { \
    throw; \
  } \
  catch (torch::jit::JITException & e) { \
    throw; \
  } \
  catch (const std::exception& e) { \
    torch::translate_exception_to_python(std::current_exception()); \
    throw py::error_already_set(); \
  }
123
// Closes a HANDLE_TH_ERRORS scope for functions that report failure through
// their return value (e.g. Python C-API entry points).
// Any exception leaving the inner block marks the warning handler via
// set_in_exception() and is rethrown. The outer handler converts any
// std::exception into a pending Python error with
// torch::translate_exception_to_python and makes the function return
// `retval`. Exceptions not derived from std::exception are not caught here.
#define END_HANDLE_TH_ERRORS_RET(retval) \
  } \
  catch (...) { \
    __enforce_warning_buffer.set_in_exception(); \
    throw; \
  } \
  } \
  catch (const std::exception& e) { \
    torch::translate_exception_to_python(std::current_exception()); \
    return retval; \
  }
135
// END_HANDLE_TH_ERRORS_RET that returns nullptr when an exception has been
// translated into a Python error (e.g. for functions returning PyObject*).
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
137
// Python exception type objects for torch-specific errors. All of them except
// THPException_FatalError are referenced by CATCH_CORE_ERRORS. They are
// defined elsewhere, presumably created by THPException_init (confirm in the
// corresponding .cpp).
extern PyObject *THPException_FatalError, *THPException_LinAlgError,
    *THPException_OutOfMemoryError, *THPException_DistError,
    *THPException_DistBackendError, *THPException_DistNetworkError,
    *THPException_DistStoreError;
142
143 // Throwing this exception means that the python error flags have been already
144 // set and control should be immediately returned to the interpreter.
145 struct python_error : public std::exception {
146 python_error() = default;
147
python_errorpython_error148 python_error(const python_error& other)
149 : type(other.type),
150 value(other.value),
151 traceback(other.traceback),
152 message(other.message) {
153 pybind11::gil_scoped_acquire gil;
154 Py_XINCREF(type);
155 Py_XINCREF(value);
156 Py_XINCREF(traceback);
157 }
158
python_errorpython_error159 python_error(python_error&& other) noexcept
160 : type(other.type),
161 value(other.value),
162 traceback(other.traceback),
163 message(std::move(other.message)) {
164 other.type = nullptr;
165 other.value = nullptr;
166 other.traceback = nullptr;
167 }
168
169 python_error& operator=(const python_error& other) = delete;
170 python_error& operator=(python_error&& other) = delete;
171
172 // NOLINTNEXTLINE(bugprone-exception-escape)
~python_errorpython_error173 ~python_error() override {
174 if (type || value || traceback) {
175 pybind11::gil_scoped_acquire gil;
176 Py_XDECREF(type);
177 Py_XDECREF(value);
178 Py_XDECREF(traceback);
179 }
180 }
181
whatpython_error182 const char* what() const noexcept override {
183 return message.c_str();
184 }
185
build_messagepython_error186 void build_message() {
187 // Ensure we have the GIL.
188 pybind11::gil_scoped_acquire gil;
189
190 // No errors should be set when we enter the function since PyErr_Fetch
191 // clears the error indicator.
192 TORCH_INTERNAL_ASSERT(!PyErr_Occurred());
193
194 // Default message.
195 message = "python_error";
196
197 // Try to retrieve the error message from the value.
198 if (value != nullptr) {
199 // Reference count should not be zero.
200 TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);
201
202 PyObject* pyStr = PyObject_Str(value);
203 if (pyStr != nullptr) {
204 PyObject* encodedString =
205 PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
206 if (encodedString != nullptr) {
207 char* bytes = PyBytes_AS_STRING(encodedString);
208 if (bytes != nullptr) {
209 // Set the message.
210 message = std::string(bytes);
211 }
212 Py_XDECREF(encodedString);
213 }
214 Py_XDECREF(pyStr);
215 }
216 }
217
218 // Clear any errors since we don't want to propagate errors for functions
219 // that are trying to build a string for the error message.
220 PyErr_Clear();
221 }
222
223 /** Saves the exception so that it can be re-thrown on a different thread */
persistpython_error224 inline void persist() {
225 if (type)
226 return; // Don't overwrite exceptions
227 // PyErr_Fetch overwrites the pointers
228 pybind11::gil_scoped_acquire gil;
229 Py_XDECREF(type);
230 Py_XDECREF(value);
231 Py_XDECREF(traceback);
232 PyErr_Fetch(&type, &value, &traceback);
233 build_message();
234 }
235
236 /** Sets the current Python error from this exception */
restorepython_error237 inline void restore() {
238 if (!type)
239 return;
240 // PyErr_Restore steals references
241 pybind11::gil_scoped_acquire gil;
242 Py_XINCREF(type);
243 Py_XINCREF(value);
244 Py_XINCREF(traceback);
245 PyErr_Restore(type, value, traceback);
246 }
247
248 PyObject* type{nullptr};
249 PyObject* value{nullptr};
250 PyObject* traceback{nullptr};
251
252 // Message to return to the user when 'what()' is invoked.
253 std::string message;
254 };
255
// Defined elsewhere. It presumably creates the THPException_* type objects and
// registers them on `module`.
// NOTE(review): the meaning of the bool result is not visible here; confirm
// (likely false on failure).
bool THPException_init(PyObject* module);
257
258 namespace torch {
259
// Set python current exception from a C++ exception.
// END_HANDLE_TH_ERRORS_PYBIND and END_HANDLE_TH_ERRORS_RET call this from
// inside a catch block, passing std::current_exception().
TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);

// Post-processes a C++ error message before it is handed to Python. Every
// message-building catch clause in the CATCH_* macros routes its text through
// this function. The transformation itself is defined elsewhere.
// NOTE(review): document what it does once confirmed.
TORCH_PYTHON_API std::string processErrorMsg(std::string str);
264
265 // Abstract base class for exceptions which translate to specific Python types
266 struct PyTorchError : public std::exception {
267 PyTorchError() = default;
PyTorchErrorPyTorchError268 PyTorchError(std::string msg_) : msg(std::move(msg_)) {}
269 virtual PyObject* python_type() = 0;
whatPyTorchError270 const char* what() const noexcept override {
271 return msg.c_str();
272 }
273 std::string msg;
274 };
275
// TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) declares a function as
// printf-like on GCC and Clang, so the compiler can warn when a format string
// does not match its arguments. Indices are 1-based; for non-static member
// functions the implicit `this` counts as argument 1. On other compilers the
// macro expands to nothing.
#if defined(__GNUC__)
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
284
285 // Translates to Python TypeError
286 struct TypeError : public PyTorchError {
287 using PyTorchError::PyTorchError;
288 TORCH_PYTHON_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeTypeError289 PyObject* python_type() override {
290 return PyExc_TypeError;
291 }
292 };
293
294 // Translates to Python AttributeError
295 struct AttributeError : public PyTorchError {
296 AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeAttributeError297 PyObject* python_type() override {
298 return PyExc_AttributeError;
299 }
300 };
301
302 // ATen warning handler for Python
303 struct PyWarningHandler {
304 // Move actual handler into a separate class with a noexcept
305 // destructor. Otherwise, we need to force all WarningHandler
306 // subclasses to have a noexcept(false) destructor.
307 struct InternalHandler : at::WarningHandler {
308 ~InternalHandler() override = default;
309 void process(const c10::Warning& warning) override;
310
311 std::vector<c10::Warning> warning_buffer_;
312 };
313
314 public:
315 /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
316 TORCH_PYTHON_API PyWarningHandler() noexcept(true);
317 // NOLINTNEXTLINE(bugprone-exception-escape)
318 TORCH_PYTHON_API ~PyWarningHandler() noexcept(false);
319
320 /** Call if an exception has been thrown
321
322 * Necessary to determine if it is safe to throw from the desctructor since
323 * std::uncaught_exception is buggy on some platforms and generally
324 * unreliable across dynamic library calls.
325 */
set_in_exceptionPyWarningHandler326 void set_in_exception() {
327 in_exception_ = true;
328 }
329
330 private:
331 InternalHandler internal_handler_;
332 at::WarningHandler* prev_handler_;
333 bool in_exception_;
334 };
335
336 namespace detail {
337
// Do-nothing stand-in for pybind11::gil_scoped_release. The
// conditional_gil_scoped_release alias selects it when release_gil is false.
struct noop_gil_scoped_release {
  // Deliberately user-provided rather than `= default`. Declaring an instance
  // that is never otherwise used then does not trigger unused-variable
  // warnings at call sites.
  noop_gil_scoped_release() {
    // Intentionally empty.
  }
};
343
344 template <bool release_gil>
345 using conditional_gil_scoped_release = std::conditional_t<
346 release_gil,
347 pybind11::gil_scoped_release,
348 noop_gil_scoped_release>;
349
350 template <typename Func, size_t i>
351 using Arg = typename invoke_traits<Func>::template arg<i>::type;
352
353 template <typename Func, size_t... Is, bool release_gil>
wrap_pybind_function_impl_(Func && f,std::index_sequence<Is...>,std::bool_constant<release_gil>)354 auto wrap_pybind_function_impl_(
355 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
356 Func&& f,
357 std::index_sequence<Is...>,
358 std::bool_constant<release_gil>) {
359 namespace py = pybind11;
360
361 // f=f is needed to handle function references on older compilers
362 return [f = std::forward<Func>(f)](Arg<Func, Is>... args) {
363 HANDLE_TH_ERRORS
364 conditional_gil_scoped_release<release_gil> no_gil;
365 return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
366 END_HANDLE_TH_ERRORS_PYBIND
367 };
368 }
369 } // namespace detail
370
371 // Wrap a function with TH error and warning handling.
372 // Returns a function object suitable for registering with pybind11.
373 template <typename Func>
wrap_pybind_function(Func && f)374 auto wrap_pybind_function(Func&& f) {
375 using traits = invoke_traits<Func>;
376 return torch::detail::wrap_pybind_function_impl_(
377 std::forward<Func>(f),
378 std::make_index_sequence<traits::arity>{},
379 std::false_type{});
380 }
381
382 // Wrap a function with TH error, warning handling and releases the GIL.
383 // Returns a function object suitable for registering with pybind11.
384 template <typename Func>
wrap_pybind_function_no_gil(Func && f)385 auto wrap_pybind_function_no_gil(Func&& f) {
386 using traits = invoke_traits<Func>;
387 return torch::detail::wrap_pybind_function_impl_(
388 std::forward<Func>(f),
389 std::make_index_sequence<traits::arity>{},
390 std::true_type{});
391 }
392
393 } // namespace torch
394