xref: /aosp_15_r20/external/pytorch/torch/csrc/Exceptions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <exception>
4 #include <memory>
5 #include <string>
6 #include <system_error>
7 
8 #include <ATen/detail/FunctionTraits.h>
9 #include <c10/util/C++17.h>
10 #include <c10/util/Exception.h>
11 #include <c10/util/StringUtil.h>
12 #include <pybind11/pybind11.h>
13 #include <torch/csrc/Export.h>
14 #include <torch/csrc/jit/runtime/jit_exception.h>
15 #include <torch/csrc/utils/cpp_stacktraces.h>
16 #include <torch/csrc/utils/pybind.h>
17 
18 #if defined(USE_DISTRIBUTED)
19 #include <torch/csrc/distributed/c10d/exception.h>
20 #endif
21 
PyErr_SetString(PyObject * type,const std::string & message)22 inline void PyErr_SetString(PyObject* type, const std::string& message) {
23   PyErr_SetString(type, message.c_str());
24 }
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set Python warnings immediately, because doing
/// so requires acquiring the GIL (a potential deadlock) and would need to exit
/// cleanly if a warning raised a Python error. Instead, warnings are buffered
/// and processed once control returns to Python.
/// This requires the two nested try/catch blocks below, which handle the
/// following cases:
///   - If no Error is raised in the inner try/catch, the buffered warnings are
///     processed as Python warnings.
///     - If none of them raises an error, the function proceeds with its
///       original return value.
///     - If any of them raises an error, the error is set (PyErr_*) and the
///       destructor throws a C++ python_error(), which is caught by the outer
///       try/catch so that it can change the function's return value to
///       reflect the error.
///   - If an Error is raised in the inner try/catch, the inner try/catch must
///     set the Python error. The buffered warnings are then processed as C++
///     warnings, because we cannot predict beforehand whether a Python warning
///     will raise an error, and we cannot handle two errors at the same time.
/// This advanced handler is only used on the current thread. Warnings raised
/// on any other thread are processed as C++ warnings.
///
/// HANDLE_TH_ERRORS opens both try blocks and must be closed by one of the
/// END_HANDLE_TH_ERRORS* macros, which refer to `__enforce_warning_buffer`
/// by name.
#define HANDLE_TH_ERRORS                                \
  try {                                                 \
    torch::PyWarningHandler __enforce_warning_buffer;   \
    try {
// Emits a handler for c10::ErrorType that sets a pending Python exception of
// type PythonErrorType and then runs `retstmnt`. The C++ backtrace is kept in
// the message only when torch::get_cpp_stacktraces_enabled() returns true.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt)     \
  catch (const c10::ErrorType& e) {                                    \
    const bool with_backtrace = torch::get_cpp_stacktraces_enabled();  \
    auto msg = with_backtrace ? e.what() : e.what_without_backtrace(); \
    PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg));     \
    retstmnt;                                                          \
  }
61 
// Only catch torch-specific exceptions.
// Every handler leaves a pending Python exception and then runs `retstmnt`:
//   - python_error and py::error_already_set already carry a Python error,
//     which is simply restored;
//   - c10 errors are mapped onto Python exception types. The first matching
//     handler wins, so specific c10 types must come before c10::Error (and,
//     presumably, the Dist* subtypes before DistError);
//   - torch::PyTorchError selects its own Python type via python_type().
#define CATCH_CORE_ERRORS(retstmnt)                                        \
  catch (python_error& e) {                                                \
    e.restore();                                                           \
    retstmnt;                                                              \
  }                                                                        \
  catch (py::error_already_set& e) {                                       \
    e.restore();                                                           \
    retstmnt;                                                              \
  }                                                                        \
  _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt)             \
  _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt)             \
  _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt)               \
  _CATCH_GENERIC_ERROR(                                                    \
      NotImplementedError, PyExc_NotImplementedError, retstmnt)            \
  _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt)    \
  _CATCH_GENERIC_ERROR(                                                    \
      OutOfMemoryError, THPException_OutOfMemoryError, retstmnt)           \
  _CATCH_GENERIC_ERROR(                                                    \
      DistBackendError, THPException_DistBackendError, retstmnt)           \
  _CATCH_GENERIC_ERROR(                                                    \
      DistNetworkError, THPException_DistNetworkError, retstmnt)           \
  _CATCH_GENERIC_ERROR(                                                    \
      DistStoreError, THPException_DistStoreError, retstmnt)               \
  _CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt)        \
  _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt)                \
  catch (torch::PyTorchError& e) {                                         \
    const auto msg = torch::processErrorMsg(e.what());                     \
    PyErr_SetString(e.python_type(), msg);                                 \
    retstmnt;                                                              \
  }

// Alias of CATCH_CORE_ERRORS kept for existing call sites.
#define CATCH_TH_ERRORS(retstmnt) CATCH_CORE_ERRORS(retstmnt)
94 
// CATCH_TH_ERRORS plus a final handler that turns any other std::exception
// into a Python RuntimeError before running `retstmnt`.
#define CATCH_ALL_ERRORS(retstmnt)                     \
  CATCH_TH_ERRORS(retstmnt)                            \
  catch (const std::exception& e) {                    \
    const auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg);          \
    retstmnt;                                          \
  }
102 
// Closes HANDLE_TH_ERRORS for functions bound through pybind11.
// Inner block: any exception marks the warning buffer as "in exception" (so
// buffered warnings are handled as C++ warnings) and is propagated.
// Outer block: py::error_already_set, py::builtin_exception and
// torch::jit::JITException are rethrown unchanged. Any other std::exception is
// converted into the current Python error and rethrown as
// py::error_already_set. Non-std exceptions propagate untouched.
#define END_HANDLE_TH_ERRORS_PYBIND                                 \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const py::error_already_set&) {                            \
    throw;                                                          \
  }                                                                 \
  catch (const py::builtin_exception&) {                            \
    throw;                                                          \
  }                                                                 \
  catch (const torch::jit::JITException&) {                         \
    throw;                                                          \
  }                                                                 \
  catch (const std::exception&) {                                   \
    torch::translate_exception_to_python(std::current_exception()); \
    throw py::error_already_set();                                  \
  }
123 
// Closes HANDLE_TH_ERRORS for plain C entry points. Any std::exception that
// escapes is converted into the current Python error, and `retval` is
// returned. This includes a python_error thrown by the warning handler's
// destructor. Non-std exceptions propagate untouched.
#define END_HANDLE_TH_ERRORS_RET(retval)                            \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const std::exception&) {                                   \
    torch::translate_exception_to_python(std::current_exception()); \
    return retval;                                                  \
  }

// Variant that reports failure by returning nullptr (e.g. for functions
// returning PyObject*).
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
137 
138 extern PyObject *THPException_FatalError, *THPException_LinAlgError,
139     *THPException_OutOfMemoryError, *THPException_DistError,
140     *THPException_DistBackendError, *THPException_DistNetworkError,
141     *THPException_DistStoreError;
142 
143 // Throwing this exception means that the python error flags have been already
144 // set and control should be immediately returned to the interpreter.
145 struct python_error : public std::exception {
146   python_error() = default;
147 
python_errorpython_error148   python_error(const python_error& other)
149       : type(other.type),
150         value(other.value),
151         traceback(other.traceback),
152         message(other.message) {
153     pybind11::gil_scoped_acquire gil;
154     Py_XINCREF(type);
155     Py_XINCREF(value);
156     Py_XINCREF(traceback);
157   }
158 
python_errorpython_error159   python_error(python_error&& other) noexcept
160       : type(other.type),
161         value(other.value),
162         traceback(other.traceback),
163         message(std::move(other.message)) {
164     other.type = nullptr;
165     other.value = nullptr;
166     other.traceback = nullptr;
167   }
168 
169   python_error& operator=(const python_error& other) = delete;
170   python_error& operator=(python_error&& other) = delete;
171 
172   // NOLINTNEXTLINE(bugprone-exception-escape)
~python_errorpython_error173   ~python_error() override {
174     if (type || value || traceback) {
175       pybind11::gil_scoped_acquire gil;
176       Py_XDECREF(type);
177       Py_XDECREF(value);
178       Py_XDECREF(traceback);
179     }
180   }
181 
whatpython_error182   const char* what() const noexcept override {
183     return message.c_str();
184   }
185 
build_messagepython_error186   void build_message() {
187     // Ensure we have the GIL.
188     pybind11::gil_scoped_acquire gil;
189 
190     // No errors should be set when we enter the function since PyErr_Fetch
191     // clears the error indicator.
192     TORCH_INTERNAL_ASSERT(!PyErr_Occurred());
193 
194     // Default message.
195     message = "python_error";
196 
197     // Try to retrieve the error message from the value.
198     if (value != nullptr) {
199       // Reference count should not be zero.
200       TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);
201 
202       PyObject* pyStr = PyObject_Str(value);
203       if (pyStr != nullptr) {
204         PyObject* encodedString =
205             PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
206         if (encodedString != nullptr) {
207           char* bytes = PyBytes_AS_STRING(encodedString);
208           if (bytes != nullptr) {
209             // Set the message.
210             message = std::string(bytes);
211           }
212           Py_XDECREF(encodedString);
213         }
214         Py_XDECREF(pyStr);
215       }
216     }
217 
218     // Clear any errors since we don't want to propagate errors for functions
219     // that are trying to build a string for the error message.
220     PyErr_Clear();
221   }
222 
223   /** Saves the exception so that it can be re-thrown on a different thread */
persistpython_error224   inline void persist() {
225     if (type)
226       return; // Don't overwrite exceptions
227     // PyErr_Fetch overwrites the pointers
228     pybind11::gil_scoped_acquire gil;
229     Py_XDECREF(type);
230     Py_XDECREF(value);
231     Py_XDECREF(traceback);
232     PyErr_Fetch(&type, &value, &traceback);
233     build_message();
234   }
235 
236   /** Sets the current Python error from this exception */
restorepython_error237   inline void restore() {
238     if (!type)
239       return;
240     // PyErr_Restore steals references
241     pybind11::gil_scoped_acquire gil;
242     Py_XINCREF(type);
243     Py_XINCREF(value);
244     Py_XINCREF(traceback);
245     PyErr_Restore(type, value, traceback);
246   }
247 
248   PyObject* type{nullptr};
249   PyObject* value{nullptr};
250   PyObject* traceback{nullptr};
251 
252   // Message to return to the user when 'what()' is invoked.
253   std::string message;
254 };
255 
256 bool THPException_init(PyObject* module);
257 
258 namespace torch {
259 
260 // Set python current exception from a C++ exception
261 TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);
262 
263 TORCH_PYTHON_API std::string processErrorMsg(std::string str);
264 
265 // Abstract base class for exceptions which translate to specific Python types
266 struct PyTorchError : public std::exception {
267   PyTorchError() = default;
PyTorchErrorPyTorchError268   PyTorchError(std::string msg_) : msg(std::move(msg_)) {}
269   virtual PyObject* python_type() = 0;
whatPyTorchError270   const char* what() const noexcept override {
271     return msg.c_str();
272   }
273   std::string msg;
274 };
275 
// Declares a printf-like function on gcc & clang so the compiler can warn
// about invalid format specifiers. FORMAT_INDEX is the 1-based position of the
// format string, and VA_ARGS_INDEX is the position of the first variadic
// argument. For non-static member functions, the implicit `this` is position 1.
#if defined(__GNUC__)
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
// Other compilers: no format checking, expands to nothing.
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
284 
285 // Translates to Python TypeError
286 struct TypeError : public PyTorchError {
287   using PyTorchError::PyTorchError;
288   TORCH_PYTHON_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeTypeError289   PyObject* python_type() override {
290     return PyExc_TypeError;
291   }
292 };
293 
294 // Translates to Python AttributeError
295 struct AttributeError : public PyTorchError {
296   AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeAttributeError297   PyObject* python_type() override {
298     return PyExc_AttributeError;
299   }
300 };
301 
302 // ATen warning handler for Python
303 struct PyWarningHandler {
304   // Move actual handler into a separate class with a noexcept
305   // destructor. Otherwise, we need to force all WarningHandler
306   // subclasses to have a noexcept(false) destructor.
307   struct InternalHandler : at::WarningHandler {
308     ~InternalHandler() override = default;
309     void process(const c10::Warning& warning) override;
310 
311     std::vector<c10::Warning> warning_buffer_;
312   };
313 
314  public:
315   /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
316   TORCH_PYTHON_API PyWarningHandler() noexcept(true);
317   // NOLINTNEXTLINE(bugprone-exception-escape)
318   TORCH_PYTHON_API ~PyWarningHandler() noexcept(false);
319 
320   /** Call if an exception has been thrown
321 
322    *  Necessary to determine if it is safe to throw from the desctructor since
323    *  std::uncaught_exception is buggy on some platforms and generally
324    *  unreliable across dynamic library calls.
325    */
set_in_exceptionPyWarningHandler326   void set_in_exception() {
327     in_exception_ = true;
328   }
329 
330  private:
331   InternalHandler internal_handler_;
332   at::WarningHandler* prev_handler_;
333   bool in_exception_;
334 };
335 
336 namespace detail {
337 
338 struct noop_gil_scoped_release {
339   // user-defined constructor (i.e. not defaulted) to avoid
340   // unused-variable warnings at usage sites of this class
noop_gil_scoped_releasenoop_gil_scoped_release341   noop_gil_scoped_release() {}
342 };
343 
344 template <bool release_gil>
345 using conditional_gil_scoped_release = std::conditional_t<
346     release_gil,
347     pybind11::gil_scoped_release,
348     noop_gil_scoped_release>;
349 
350 template <typename Func, size_t i>
351 using Arg = typename invoke_traits<Func>::template arg<i>::type;
352 
353 template <typename Func, size_t... Is, bool release_gil>
wrap_pybind_function_impl_(Func && f,std::index_sequence<Is...>,std::bool_constant<release_gil>)354 auto wrap_pybind_function_impl_(
355     // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
356     Func&& f,
357     std::index_sequence<Is...>,
358     std::bool_constant<release_gil>) {
359   namespace py = pybind11;
360 
361   // f=f is needed to handle function references on older compilers
362   return [f = std::forward<Func>(f)](Arg<Func, Is>... args) {
363     HANDLE_TH_ERRORS
364     conditional_gil_scoped_release<release_gil> no_gil;
365     return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
366     END_HANDLE_TH_ERRORS_PYBIND
367   };
368 }
369 } // namespace detail
370 
371 // Wrap a function with TH error and warning handling.
372 // Returns a function object suitable for registering with pybind11.
373 template <typename Func>
wrap_pybind_function(Func && f)374 auto wrap_pybind_function(Func&& f) {
375   using traits = invoke_traits<Func>;
376   return torch::detail::wrap_pybind_function_impl_(
377       std::forward<Func>(f),
378       std::make_index_sequence<traits::arity>{},
379       std::false_type{});
380 }
381 
382 // Wrap a function with TH error, warning handling and releases the GIL.
383 // Returns a function object suitable for registering with pybind11.
384 template <typename Func>
wrap_pybind_function_no_gil(Func && f)385 auto wrap_pybind_function_no_gil(Func&& f) {
386   using traits = invoke_traits<Func>;
387   return torch::detail::wrap_pybind_function_impl_(
388       std::forward<Func>(f),
389       std::make_index_sequence<traits::arity>{},
390       std::true_type{});
391 }
392 
393 } // namespace torch
394