1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <exception>
4*da0073e9SAndroid Build Coastguard Worker #include <memory>
5*da0073e9SAndroid Build Coastguard Worker #include <string>
6*da0073e9SAndroid Build Coastguard Worker #include <system_error>
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker #include <ATen/detail/FunctionTraits.h>
9*da0073e9SAndroid Build Coastguard Worker #include <c10/util/C++17.h>
10*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
11*da0073e9SAndroid Build Coastguard Worker #include <c10/util/StringUtil.h>
12*da0073e9SAndroid Build Coastguard Worker #include <pybind11/pybind11.h>
13*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/Export.h>
14*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/runtime/jit_exception.h>
15*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/cpp_stacktraces.h>
16*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/pybind.h>
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker #if defined(USE_DISTRIBUTED)
19*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/distributed/c10d/exception.h>
20*da0073e9SAndroid Build Coastguard Worker #endif
21*da0073e9SAndroid Build Coastguard Worker
PyErr_SetString(PyObject * type,const std::string & message)22*da0073e9SAndroid Build Coastguard Worker inline void PyErr_SetString(PyObject* type, const std::string& message) {
23*da0073e9SAndroid Build Coastguard Worker PyErr_SetString(type, message.c_str());
24*da0073e9SAndroid Build Coastguard Worker }
25*da0073e9SAndroid Build Coastguard Worker /// NOTE [ Conversion Cpp Python Warning ]
26*da0073e9SAndroid Build Coastguard Worker /// The warning handler cannot set python warnings immediately
27*da0073e9SAndroid Build Coastguard Worker /// as it requires acquiring the GIL (potential deadlock)
28*da0073e9SAndroid Build Coastguard Worker /// and would need to cleanly exit if the warning raised a
29*da0073e9SAndroid Build Coastguard Worker /// python error. To solve this, we buffer the warnings and
30*da0073e9SAndroid Build Coastguard Worker /// process them when we go back to python.
31*da0073e9SAndroid Build Coastguard Worker /// This requires the two try/catch blocks below to handle the
32*da0073e9SAndroid Build Coastguard Worker /// following cases:
33*da0073e9SAndroid Build Coastguard Worker /// - If there is no Error raised in the inner try/catch, the
34*da0073e9SAndroid Build Coastguard Worker /// buffered warnings are processed as python warnings.
///     - If they don't raise an error, the function proceeds with the
///       original return code.
37*da0073e9SAndroid Build Coastguard Worker /// - If any of them raise an error, the error is set (PyErr_*) and
38*da0073e9SAndroid Build Coastguard Worker /// the destructor will raise a cpp exception python_error() that
39*da0073e9SAndroid Build Coastguard Worker /// will be caught by the outer try/catch that will be able to change
40*da0073e9SAndroid Build Coastguard Worker /// the return value of the function to reflect the error.
41*da0073e9SAndroid Build Coastguard Worker /// - If an Error was raised in the inner try/catch, the inner try/catch
42*da0073e9SAndroid Build Coastguard Worker /// must set the python error. The buffered warnings are then
43*da0073e9SAndroid Build Coastguard Worker /// processed as cpp warnings as we cannot predict before hand
44*da0073e9SAndroid Build Coastguard Worker /// whether a python warning will raise an error or not and we
45*da0073e9SAndroid Build Coastguard Worker /// cannot handle two errors at the same time.
46*da0073e9SAndroid Build Coastguard Worker /// This advanced handler will only be used in the current thread.
47*da0073e9SAndroid Build Coastguard Worker /// If any other thread is used, warnings will be processed as
48*da0073e9SAndroid Build Coastguard Worker /// cpp warnings.
// Opens the error-handling scope described in the NOTE above: an outer `try`
// that owns a `torch::PyWarningHandler` named `__enforce_warning_buffer`, and
// an inner `try` that wraps the code following the macro. Both blocks are left
// open here; the END_HANDLE_TH_ERRORS* macros close them and refer to
// `__enforce_warning_buffer` by name.
#define HANDLE_TH_ERRORS                              \
  try {                                               \
    torch::PyWarningHandler __enforce_warning_buffer; \
    try {
// Emits one `catch` clause for `c10::ErrorType`. The handler picks the
// message with or without the C++ backtrace depending on
// `torch::get_cpp_stacktraces_enabled()`, sets `PythonErrorType` as the
// current Python error with the processed message, then runs `retstmnt`.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt) \
  catch (const c10::ErrorType& e) {                                \
    const char* msg = nullptr;                                     \
    if (torch::get_cpp_stacktraces_enabled()) {                    \
      msg = e.what();                                              \
    } else {                                                       \
      msg = e.what_without_backtrace();                            \
    }                                                              \
    PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
    retstmnt;                                                      \
  }
61*da0073e9SAndroid Build Coastguard Worker
// Only catch torch-specific exceptions.
//
// Expands to a chain of `catch` clauses; each one sets the Python error
// indicator and then executes `retstmnt`:
//   - `python_error` and `py::error_already_set` are handed back to Python
//     through their own `restore()`.
//   - each listed `c10::` error type is mapped by _CATCH_GENERIC_ERROR to the
//     Python exception object named next to it.
//   - `torch::PyTorchError` uses the type reported by its `python_type()`.
// Handlers are tried in the order written. The generic `c10::Error` entry
// (mapped to RuntimeError) comes after the more specific c10 entries --
// presumably because those types derive from it; confirm in
// c10/util/Exception.h before reordering.
// The expansion names `py::`, so a `py` namespace alias must be visible at the
// point of use.
#define CATCH_CORE_ERRORS(retstmnt)                                           \
  catch (python_error & e) {                                                  \
    e.restore();                                                              \
    retstmnt;                                                                 \
  }                                                                           \
  catch (py::error_already_set & e) {                                         \
    e.restore();                                                              \
    retstmnt;                                                                 \
  }                                                                           \
  _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt)                \
  _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt)                \
  _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt)                  \
  _CATCH_GENERIC_ERROR(                                                       \
      NotImplementedError, PyExc_NotImplementedError, retstmnt)               \
  _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt)       \
  _CATCH_GENERIC_ERROR(                                                       \
      OutOfMemoryError, THPException_OutOfMemoryError, retstmnt)              \
  _CATCH_GENERIC_ERROR(                                                       \
      DistBackendError, THPException_DistBackendError, retstmnt)              \
  _CATCH_GENERIC_ERROR(                                                       \
      DistNetworkError, THPException_DistNetworkError, retstmnt)              \
  _CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \
  _CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt)           \
  _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt)                   \
  catch (torch::PyTorchError & e) {                                           \
    auto msg = torch::processErrorMsg(e.what());                              \
    PyErr_SetString(e.python_type(), msg);                                    \
    retstmnt;                                                                 \
  }
92*da0073e9SAndroid Build Coastguard Worker
// Alias: catches exactly what CATCH_CORE_ERRORS catches.
#define CATCH_TH_ERRORS(stmt) CATCH_CORE_ERRORS(stmt)

// CATCH_TH_ERRORS followed by a fallback handler for any other
// `std::exception`, which is reported to Python as a RuntimeError carrying
// the processed `what()` text. `stmt` runs after the error has been set.
#define CATCH_ALL_ERRORS(stmt)                                \
  CATCH_TH_ERRORS(stmt)                                       \
  catch (const std::exception& e) {                           \
    const std::string msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg);                 \
    stmt;                                                     \
  }
102*da0073e9SAndroid Build Coastguard Worker
// Closes the two `try` blocks opened by HANDLE_TH_ERRORS for functions bound
// through pybind11, where errors are reported by throwing.
//   - Inner `catch (...)`: tells the warning handler that an exception is in
//     flight, then rethrows.
//   - Outer handlers: `py::error_already_set`, `py::builtin_exception` and
//     `torch::jit::JITException` are rethrown untouched; any other
//     `std::exception` is translated into the Python error indicator and
//     replaced by a fresh `py::error_already_set`.
#define END_HANDLE_TH_ERRORS_PYBIND                                 \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (py::error_already_set&) {                                  \
    throw;                                                          \
  }                                                                 \
  catch (py::builtin_exception&) {                                  \
    throw;                                                          \
  }                                                                 \
  catch (torch::jit::JITException&) {                               \
    throw;                                                          \
  }                                                                 \
  catch (const std::exception&) {                                   \
    torch::translate_exception_to_python(std::current_exception()); \
    throw py::error_already_set();                                  \
  }
123*da0073e9SAndroid Build Coastguard Worker
// Closes the two `try` blocks opened by HANDLE_TH_ERRORS for functions that
// report failure through their return value.
//   - Inner `catch (...)`: tells the warning handler that an exception is in
//     flight, then rethrows.
//   - Outer handler: any `std::exception` is translated into the Python error
//     indicator and the enclosing function returns `retval`.
#define END_HANDLE_TH_ERRORS_RET(retval)                            \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const std::exception&) {                                   \
    torch::translate_exception_to_python(std::current_exception()); \
    return retval;                                                  \
  }

// Common case: the enclosing function signals failure by returning nullptr.
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker extern PyObject *THPException_FatalError, *THPException_LinAlgError,
139*da0073e9SAndroid Build Coastguard Worker *THPException_OutOfMemoryError, *THPException_DistError,
140*da0073e9SAndroid Build Coastguard Worker *THPException_DistBackendError, *THPException_DistNetworkError,
141*da0073e9SAndroid Build Coastguard Worker *THPException_DistStoreError;
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker // Throwing this exception means that the python error flags have been already
144*da0073e9SAndroid Build Coastguard Worker // set and control should be immediately returned to the interpreter.
145*da0073e9SAndroid Build Coastguard Worker struct python_error : public std::exception {
146*da0073e9SAndroid Build Coastguard Worker python_error() = default;
147*da0073e9SAndroid Build Coastguard Worker
python_errorpython_error148*da0073e9SAndroid Build Coastguard Worker python_error(const python_error& other)
149*da0073e9SAndroid Build Coastguard Worker : type(other.type),
150*da0073e9SAndroid Build Coastguard Worker value(other.value),
151*da0073e9SAndroid Build Coastguard Worker traceback(other.traceback),
152*da0073e9SAndroid Build Coastguard Worker message(other.message) {
153*da0073e9SAndroid Build Coastguard Worker pybind11::gil_scoped_acquire gil;
154*da0073e9SAndroid Build Coastguard Worker Py_XINCREF(type);
155*da0073e9SAndroid Build Coastguard Worker Py_XINCREF(value);
156*da0073e9SAndroid Build Coastguard Worker Py_XINCREF(traceback);
157*da0073e9SAndroid Build Coastguard Worker }
158*da0073e9SAndroid Build Coastguard Worker
python_errorpython_error159*da0073e9SAndroid Build Coastguard Worker python_error(python_error&& other) noexcept
160*da0073e9SAndroid Build Coastguard Worker : type(other.type),
161*da0073e9SAndroid Build Coastguard Worker value(other.value),
162*da0073e9SAndroid Build Coastguard Worker traceback(other.traceback),
163*da0073e9SAndroid Build Coastguard Worker message(std::move(other.message)) {
164*da0073e9SAndroid Build Coastguard Worker other.type = nullptr;
165*da0073e9SAndroid Build Coastguard Worker other.value = nullptr;
166*da0073e9SAndroid Build Coastguard Worker other.traceback = nullptr;
167*da0073e9SAndroid Build Coastguard Worker }
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker python_error& operator=(const python_error& other) = delete;
170*da0073e9SAndroid Build Coastguard Worker python_error& operator=(python_error&& other) = delete;
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(bugprone-exception-escape)
~python_errorpython_error173*da0073e9SAndroid Build Coastguard Worker ~python_error() override {
174*da0073e9SAndroid Build Coastguard Worker if (type || value || traceback) {
175*da0073e9SAndroid Build Coastguard Worker pybind11::gil_scoped_acquire gil;
176*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(type);
177*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(value);
178*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(traceback);
179*da0073e9SAndroid Build Coastguard Worker }
180*da0073e9SAndroid Build Coastguard Worker }
181*da0073e9SAndroid Build Coastguard Worker
whatpython_error182*da0073e9SAndroid Build Coastguard Worker const char* what() const noexcept override {
183*da0073e9SAndroid Build Coastguard Worker return message.c_str();
184*da0073e9SAndroid Build Coastguard Worker }
185*da0073e9SAndroid Build Coastguard Worker
build_messagepython_error186*da0073e9SAndroid Build Coastguard Worker void build_message() {
187*da0073e9SAndroid Build Coastguard Worker // Ensure we have the GIL.
188*da0073e9SAndroid Build Coastguard Worker pybind11::gil_scoped_acquire gil;
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker // No errors should be set when we enter the function since PyErr_Fetch
191*da0073e9SAndroid Build Coastguard Worker // clears the error indicator.
192*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(!PyErr_Occurred());
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker // Default message.
195*da0073e9SAndroid Build Coastguard Worker message = "python_error";
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker // Try to retrieve the error message from the value.
198*da0073e9SAndroid Build Coastguard Worker if (value != nullptr) {
199*da0073e9SAndroid Build Coastguard Worker // Reference count should not be zero.
200*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker PyObject* pyStr = PyObject_Str(value);
203*da0073e9SAndroid Build Coastguard Worker if (pyStr != nullptr) {
204*da0073e9SAndroid Build Coastguard Worker PyObject* encodedString =
205*da0073e9SAndroid Build Coastguard Worker PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
206*da0073e9SAndroid Build Coastguard Worker if (encodedString != nullptr) {
207*da0073e9SAndroid Build Coastguard Worker char* bytes = PyBytes_AS_STRING(encodedString);
208*da0073e9SAndroid Build Coastguard Worker if (bytes != nullptr) {
209*da0073e9SAndroid Build Coastguard Worker // Set the message.
210*da0073e9SAndroid Build Coastguard Worker message = std::string(bytes);
211*da0073e9SAndroid Build Coastguard Worker }
212*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(encodedString);
213*da0073e9SAndroid Build Coastguard Worker }
214*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(pyStr);
215*da0073e9SAndroid Build Coastguard Worker }
216*da0073e9SAndroid Build Coastguard Worker }
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker // Clear any errors since we don't want to propagate errors for functions
219*da0073e9SAndroid Build Coastguard Worker // that are trying to build a string for the error message.
220*da0073e9SAndroid Build Coastguard Worker PyErr_Clear();
221*da0073e9SAndroid Build Coastguard Worker }
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker /** Saves the exception so that it can be re-thrown on a different thread */
persistpython_error224*da0073e9SAndroid Build Coastguard Worker inline void persist() {
225*da0073e9SAndroid Build Coastguard Worker if (type)
226*da0073e9SAndroid Build Coastguard Worker return; // Don't overwrite exceptions
227*da0073e9SAndroid Build Coastguard Worker // PyErr_Fetch overwrites the pointers
228*da0073e9SAndroid Build Coastguard Worker pybind11::gil_scoped_acquire gil;
229*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(type);
230*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(value);
231*da0073e9SAndroid Build Coastguard Worker Py_XDECREF(traceback);
232*da0073e9SAndroid Build Coastguard Worker PyErr_Fetch(&type, &value, &traceback);
233*da0073e9SAndroid Build Coastguard Worker build_message();
234*da0073e9SAndroid Build Coastguard Worker }
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker /** Sets the current Python error from this exception */
restorepython_error237*da0073e9SAndroid Build Coastguard Worker inline void restore() {
238*da0073e9SAndroid Build Coastguard Worker if (!type)
239*da0073e9SAndroid Build Coastguard Worker return;
240*da0073e9SAndroid Build Coastguard Worker // PyErr_Restore steals references
241*da0073e9SAndroid Build Coastguard Worker pybind11::gil_scoped_acquire gil;
242*da0073e9SAndroid Build Coastguard Worker Py_XINCREF(type);
243*da0073e9SAndroid Build Coastguard Worker Py_XINCREF(value);
244*da0073e9SAndroid Build Coastguard Worker Py_XINCREF(traceback);
245*da0073e9SAndroid Build Coastguard Worker PyErr_Restore(type, value, traceback);
246*da0073e9SAndroid Build Coastguard Worker }
247*da0073e9SAndroid Build Coastguard Worker
248*da0073e9SAndroid Build Coastguard Worker PyObject* type{nullptr};
249*da0073e9SAndroid Build Coastguard Worker PyObject* value{nullptr};
250*da0073e9SAndroid Build Coastguard Worker PyObject* traceback{nullptr};
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker // Message to return to the user when 'what()' is invoked.
253*da0073e9SAndroid Build Coastguard Worker std::string message;
254*da0073e9SAndroid Build Coastguard Worker };
255*da0073e9SAndroid Build Coastguard Worker
// Declared here and defined elsewhere.
// NOTE(review): presumably sets up the THPException_* objects on `module` and
// reports success through the return value -- confirm against the definition.
bool THPException_init(PyObject* module);
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker namespace torch {
259*da0073e9SAndroid Build Coastguard Worker
// Set python current exception from a C++ exception.
// The END_HANDLE_TH_ERRORS* macros call this with `std::current_exception()`
// from inside a `catch` block.
TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);

// Post-processes an error message before it is passed to `PyErr_SetString`
// (see the CATCH_* macros).
// NOTE(review): the transformation itself is defined out of line -- check the
// implementation before relying on its exact output.
TORCH_PYTHON_API std::string processErrorMsg(std::string str);
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker // Abstract base class for exceptions which translate to specific Python types
266*da0073e9SAndroid Build Coastguard Worker struct PyTorchError : public std::exception {
267*da0073e9SAndroid Build Coastguard Worker PyTorchError() = default;
PyTorchErrorPyTorchError268*da0073e9SAndroid Build Coastguard Worker PyTorchError(std::string msg_) : msg(std::move(msg_)) {}
269*da0073e9SAndroid Build Coastguard Worker virtual PyObject* python_type() = 0;
whatPyTorchError270*da0073e9SAndroid Build Coastguard Worker const char* what() const noexcept override {
271*da0073e9SAndroid Build Coastguard Worker return msg.c_str();
272*da0073e9SAndroid Build Coastguard Worker }
273*da0073e9SAndroid Build Coastguard Worker std::string msg;
274*da0073e9SAndroid Build Coastguard Worker };
275*da0073e9SAndroid Build Coastguard Worker
// Marks a function as printf-like on compilers that define __GNUC__, so that
// invalid format specifiers can be diagnosed at compile time. FORMAT_INDEX and
// VA_ARGS_INDEX are the 1-based positions of the format string and of the
// first variadic argument. Expands to nothing on other compilers.
#if defined(__GNUC__)
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker // Translates to Python TypeError
286*da0073e9SAndroid Build Coastguard Worker struct TypeError : public PyTorchError {
287*da0073e9SAndroid Build Coastguard Worker using PyTorchError::PyTorchError;
288*da0073e9SAndroid Build Coastguard Worker TORCH_PYTHON_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeTypeError289*da0073e9SAndroid Build Coastguard Worker PyObject* python_type() override {
290*da0073e9SAndroid Build Coastguard Worker return PyExc_TypeError;
291*da0073e9SAndroid Build Coastguard Worker }
292*da0073e9SAndroid Build Coastguard Worker };
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker // Translates to Python AttributeError
295*da0073e9SAndroid Build Coastguard Worker struct AttributeError : public PyTorchError {
296*da0073e9SAndroid Build Coastguard Worker AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeAttributeError297*da0073e9SAndroid Build Coastguard Worker PyObject* python_type() override {
298*da0073e9SAndroid Build Coastguard Worker return PyExc_AttributeError;
299*da0073e9SAndroid Build Coastguard Worker }
300*da0073e9SAndroid Build Coastguard Worker };
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker // ATen warning handler for Python
303*da0073e9SAndroid Build Coastguard Worker struct PyWarningHandler {
304*da0073e9SAndroid Build Coastguard Worker // Move actual handler into a separate class with a noexcept
305*da0073e9SAndroid Build Coastguard Worker // destructor. Otherwise, we need to force all WarningHandler
306*da0073e9SAndroid Build Coastguard Worker // subclasses to have a noexcept(false) destructor.
307*da0073e9SAndroid Build Coastguard Worker struct InternalHandler : at::WarningHandler {
308*da0073e9SAndroid Build Coastguard Worker ~InternalHandler() override = default;
309*da0073e9SAndroid Build Coastguard Worker void process(const c10::Warning& warning) override;
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker std::vector<c10::Warning> warning_buffer_;
312*da0073e9SAndroid Build Coastguard Worker };
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker public:
315*da0073e9SAndroid Build Coastguard Worker /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
316*da0073e9SAndroid Build Coastguard Worker TORCH_PYTHON_API PyWarningHandler() noexcept(true);
317*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(bugprone-exception-escape)
318*da0073e9SAndroid Build Coastguard Worker TORCH_PYTHON_API ~PyWarningHandler() noexcept(false);
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker /** Call if an exception has been thrown
321*da0073e9SAndroid Build Coastguard Worker
322*da0073e9SAndroid Build Coastguard Worker * Necessary to determine if it is safe to throw from the desctructor since
323*da0073e9SAndroid Build Coastguard Worker * std::uncaught_exception is buggy on some platforms and generally
324*da0073e9SAndroid Build Coastguard Worker * unreliable across dynamic library calls.
325*da0073e9SAndroid Build Coastguard Worker */
set_in_exceptionPyWarningHandler326*da0073e9SAndroid Build Coastguard Worker void set_in_exception() {
327*da0073e9SAndroid Build Coastguard Worker in_exception_ = true;
328*da0073e9SAndroid Build Coastguard Worker }
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker private:
331*da0073e9SAndroid Build Coastguard Worker InternalHandler internal_handler_;
332*da0073e9SAndroid Build Coastguard Worker at::WarningHandler* prev_handler_;
333*da0073e9SAndroid Build Coastguard Worker bool in_exception_;
334*da0073e9SAndroid Build Coastguard Worker };
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker namespace detail {
337*da0073e9SAndroid Build Coastguard Worker
338*da0073e9SAndroid Build Coastguard Worker struct noop_gil_scoped_release {
339*da0073e9SAndroid Build Coastguard Worker // user-defined constructor (i.e. not defaulted) to avoid
340*da0073e9SAndroid Build Coastguard Worker // unused-variable warnings at usage sites of this class
noop_gil_scoped_releasenoop_gil_scoped_release341*da0073e9SAndroid Build Coastguard Worker noop_gil_scoped_release() {}
342*da0073e9SAndroid Build Coastguard Worker };
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker template <bool release_gil>
345*da0073e9SAndroid Build Coastguard Worker using conditional_gil_scoped_release = std::conditional_t<
346*da0073e9SAndroid Build Coastguard Worker release_gil,
347*da0073e9SAndroid Build Coastguard Worker pybind11::gil_scoped_release,
348*da0073e9SAndroid Build Coastguard Worker noop_gil_scoped_release>;
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker template <typename Func, size_t i>
351*da0073e9SAndroid Build Coastguard Worker using Arg = typename invoke_traits<Func>::template arg<i>::type;
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker template <typename Func, size_t... Is, bool release_gil>
wrap_pybind_function_impl_(Func && f,std::index_sequence<Is...>,std::bool_constant<release_gil>)354*da0073e9SAndroid Build Coastguard Worker auto wrap_pybind_function_impl_(
355*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
356*da0073e9SAndroid Build Coastguard Worker Func&& f,
357*da0073e9SAndroid Build Coastguard Worker std::index_sequence<Is...>,
358*da0073e9SAndroid Build Coastguard Worker std::bool_constant<release_gil>) {
359*da0073e9SAndroid Build Coastguard Worker namespace py = pybind11;
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker // f=f is needed to handle function references on older compilers
362*da0073e9SAndroid Build Coastguard Worker return [f = std::forward<Func>(f)](Arg<Func, Is>... args) {
363*da0073e9SAndroid Build Coastguard Worker HANDLE_TH_ERRORS
364*da0073e9SAndroid Build Coastguard Worker conditional_gil_scoped_release<release_gil> no_gil;
365*da0073e9SAndroid Build Coastguard Worker return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
366*da0073e9SAndroid Build Coastguard Worker END_HANDLE_TH_ERRORS_PYBIND
367*da0073e9SAndroid Build Coastguard Worker };
368*da0073e9SAndroid Build Coastguard Worker }
369*da0073e9SAndroid Build Coastguard Worker } // namespace detail
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker // Wrap a function with TH error and warning handling.
372*da0073e9SAndroid Build Coastguard Worker // Returns a function object suitable for registering with pybind11.
373*da0073e9SAndroid Build Coastguard Worker template <typename Func>
wrap_pybind_function(Func && f)374*da0073e9SAndroid Build Coastguard Worker auto wrap_pybind_function(Func&& f) {
375*da0073e9SAndroid Build Coastguard Worker using traits = invoke_traits<Func>;
376*da0073e9SAndroid Build Coastguard Worker return torch::detail::wrap_pybind_function_impl_(
377*da0073e9SAndroid Build Coastguard Worker std::forward<Func>(f),
378*da0073e9SAndroid Build Coastguard Worker std::make_index_sequence<traits::arity>{},
379*da0073e9SAndroid Build Coastguard Worker std::false_type{});
380*da0073e9SAndroid Build Coastguard Worker }
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker // Wrap a function with TH error, warning handling and releases the GIL.
383*da0073e9SAndroid Build Coastguard Worker // Returns a function object suitable for registering with pybind11.
384*da0073e9SAndroid Build Coastguard Worker template <typename Func>
wrap_pybind_function_no_gil(Func && f)385*da0073e9SAndroid Build Coastguard Worker auto wrap_pybind_function_no_gil(Func&& f) {
386*da0073e9SAndroid Build Coastguard Worker using traits = invoke_traits<Func>;
387*da0073e9SAndroid Build Coastguard Worker return torch::detail::wrap_pybind_function_impl_(
388*da0073e9SAndroid Build Coastguard Worker std::forward<Func>(f),
389*da0073e9SAndroid Build Coastguard Worker std::make_index_sequence<traits::arity>{},
390*da0073e9SAndroid Build Coastguard Worker std::true_type{});
391*da0073e9SAndroid Build Coastguard Worker }
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker } // namespace torch
394