xref: /aosp_15_r20/external/pytorch/torch/csrc/Exceptions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <exception>
#include <functional>
#include <memory>
#include <string>
#include <system_error>
#include <utility>

#include <ATen/detail/FunctionTraits.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <torch/csrc/utils/pybind.h>

#if defined(USE_DISTRIBUTED)
#include <torch/csrc/distributed/c10d/exception.h>
#endif
21*da0073e9SAndroid Build Coastguard Worker 
PyErr_SetString(PyObject * type,const std::string & message)22*da0073e9SAndroid Build Coastguard Worker inline void PyErr_SetString(PyObject* type, const std::string& message) {
23*da0073e9SAndroid Build Coastguard Worker   PyErr_SetString(type, message.c_str());
24*da0073e9SAndroid Build Coastguard Worker }
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set python warnings immediately
/// as it requires acquiring the GIL (potential deadlock)
/// and would need to cleanly exit if the warning raised a
/// python error. To solve this, we buffer the warnings and
/// process them when we go back to python.
/// This requires the two try/catch blocks below to handle the
/// following cases:
///   - If there is no Error raised in the inner try/catch, the
///     buffered warnings are processed as python warnings.
///     - If they don't raise an error, the function proceeds with the
///       original return code.
///     - If any of them raise an error, the error is set (PyErr_*) and
///       the destructor will raise a cpp exception python_error() that
///       will be caught by the outer try/catch that will be able to change
///       the return value of the function to reflect the error.
///   - If an Error was raised in the inner try/catch, the inner try/catch
///     must set the python error. The buffered warnings are then
///     processed as cpp warnings as we cannot predict beforehand
///     whether a python warning will raise an error or not and we
///     cannot handle two errors at the same time.
/// This advanced handler will only be used in the current thread.
/// If any other thread is used, warnings will be processed as
/// cpp warnings.
///
/// HANDLE_TH_ERRORS opens an outer `try`, declares the
/// torch::PyWarningHandler `__enforce_warning_buffer` inside it, and then opens
/// an inner `try`. Both blocks are left open: the macro must be paired with one
/// of the END_HANDLE_TH_ERRORS* macros. Those close the inner block (calling
/// `__enforce_warning_buffer.set_in_exception()` before rethrowing) and then
/// the outer block. Because the handler lives in the outer block, anything its
/// noexcept(false) destructor throws is seen by the outer catch clauses.
#define HANDLE_TH_ERRORS                              \
  try {                                               \
    torch::PyWarningHandler __enforce_warning_buffer; \
    try {
// Expands to a single `catch (const c10::ErrorType& e)` clause (no `try`).
// The clause works as follows:
//   - It picks the message: `e.what()` when
//     torch::get_cpp_stacktraces_enabled() is true, otherwise
//     `e.what_without_backtrace()`.
//   - It passes that message through torch::processErrorMsg.
//   - It raises it as the Python exception `PythonErrorType`.
//   - It runs the caller-supplied `retstmnt`.
// `e` and `msg` are in scope for `retstmnt`.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt) \
  catch (const c10::ErrorType& e) {                                \
    auto msg = torch::get_cpp_stacktraces_enabled()                \
        ? e.what()                                                 \
        : e.what_without_backtrace();                              \
    PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
    retstmnt;                                                      \
  }
61*da0073e9SAndroid Build Coastguard Worker 
// Only catch torch-specific exceptions
//
// CATCH_CORE_ERRORS(retstmnt) expands to a chain of catch clauses (no `try`).
// Each clause sets or re-installs a Python error and then runs `retstmnt`:
//   - python_error / py::error_already_set: the error object already carries
//     the Python exception; `e.restore()` re-installs it. python_error::restore
//     does nothing if no error was stored.
//   - c10 error types: mapped to the listed Python exception classes via
//     _CATCH_GENERIC_ERROR.
//   - torch::PyTorchError: raised as the subclass's `python_type()`.
// Catch clauses are tried in order. The specific c10 types must therefore
// stay ahead of the generic `Error` clause, and the Dist*Error types ahead of
// `DistError`. This assumes they derive from those bases; verify in
// c10/util/Exception.h before reordering.
#define CATCH_CORE_ERRORS(retstmnt)                                           \
  catch (python_error & e) {                                                  \
    e.restore();                                                              \
    retstmnt;                                                                 \
  }                                                                           \
  catch (py::error_already_set & e) {                                         \
    e.restore();                                                              \
    retstmnt;                                                                 \
  }                                                                           \
  _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt)                \
  _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt)                \
  _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt)                  \
  _CATCH_GENERIC_ERROR(                                                       \
      NotImplementedError, PyExc_NotImplementedError, retstmnt)               \
  _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt)       \
  _CATCH_GENERIC_ERROR(                                                       \
      OutOfMemoryError, THPException_OutOfMemoryError, retstmnt)              \
  _CATCH_GENERIC_ERROR(                                                       \
      DistBackendError, THPException_DistBackendError, retstmnt)              \
  _CATCH_GENERIC_ERROR(                                                       \
      DistNetworkError, THPException_DistNetworkError, retstmnt)              \
  _CATCH_GENERIC_ERROR(DistStoreError, THPException_DistStoreError, retstmnt) \
  _CATCH_GENERIC_ERROR(DistError, THPException_DistError, retstmnt)           \
  _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt)                   \
  catch (torch::PyTorchError & e) {                                           \
    auto msg = torch::processErrorMsg(e.what());                              \
    PyErr_SetString(e.python_type(), msg);                                    \
    retstmnt;                                                                 \
  }

// Alias of CATCH_CORE_ERRORS.
#define CATCH_TH_ERRORS(retstmnt) CATCH_CORE_ERRORS(retstmnt)
94*da0073e9SAndroid Build Coastguard Worker 
// CATCH_TH_ERRORS plus a final `const std::exception&` clause. Any other
// standard exception has its processed `what()` text raised as a Python
// RuntimeError before `retstmnt` runs. Exceptions not derived from
// std::exception are not caught.
#define CATCH_ALL_ERRORS(retstmnt)               \
  CATCH_TH_ERRORS(retstmnt)                      \
  catch (const std::exception& e) {              \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg);    \
    retstmnt;                                    \
  }
102*da0073e9SAndroid Build Coastguard Worker 
// Closes a HANDLE_TH_ERRORS region inside a function registered with pybind11.
// The macro has two stages:
//   - Inner block: any exception marks the warning handler as being in an
//     exception (so buffered warnings are not turned into Python errors) and
//     is rethrown.
//   - Outer block: py::error_already_set, py::builtin_exception and
//     torch::jit::JITException are rethrown unchanged; presumably pybind11's
//     own translation handles them. Any other std::exception is converted to
//     a Python error by torch::translate_exception_to_python and then
//     re-thrown as py::error_already_set.
// Exceptions not derived from std::exception propagate untouched.
#define END_HANDLE_TH_ERRORS_PYBIND                                 \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (py::error_already_set & e) {                               \
    throw;                                                          \
  }                                                                 \
  catch (py::builtin_exception & e) {                               \
    throw;                                                          \
  }                                                                 \
  catch (torch::jit::JITException & e) {                            \
    throw;                                                          \
  }                                                                 \
  catch (const std::exception& e) {                                 \
    torch::translate_exception_to_python(std::current_exception()); \
    throw py::error_already_set();                                  \
  }
123*da0073e9SAndroid Build Coastguard Worker 
// Closes a HANDLE_TH_ERRORS region in a raw CPython-style function.
// The macro has two stages:
//   - Inner block: marks the warning handler as in-exception and rethrows.
//   - Outer block: any std::exception, including one thrown by the warning
//     handler's destructor, is converted to the current Python error by
//     torch::translate_exception_to_python, and the function returns
//     `retval` (the caller's error sentinel).
// Non-std exceptions are not caught here.
#define END_HANDLE_TH_ERRORS_RET(retval)                            \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const std::exception& e) {                                 \
    torch::translate_exception_to_python(std::current_exception()); \
    return retval;                                                  \
  }

// Variant for functions returning PyObject*-like pointers: returns nullptr on
// error.
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
137*da0073e9SAndroid Build Coastguard Worker 
138*da0073e9SAndroid Build Coastguard Worker extern PyObject *THPException_FatalError, *THPException_LinAlgError,
139*da0073e9SAndroid Build Coastguard Worker     *THPException_OutOfMemoryError, *THPException_DistError,
140*da0073e9SAndroid Build Coastguard Worker     *THPException_DistBackendError, *THPException_DistNetworkError,
141*da0073e9SAndroid Build Coastguard Worker     *THPException_DistStoreError;
142*da0073e9SAndroid Build Coastguard Worker 
143*da0073e9SAndroid Build Coastguard Worker // Throwing this exception means that the python error flags have been already
144*da0073e9SAndroid Build Coastguard Worker // set and control should be immediately returned to the interpreter.
145*da0073e9SAndroid Build Coastguard Worker struct python_error : public std::exception {
146*da0073e9SAndroid Build Coastguard Worker   python_error() = default;
147*da0073e9SAndroid Build Coastguard Worker 
python_errorpython_error148*da0073e9SAndroid Build Coastguard Worker   python_error(const python_error& other)
149*da0073e9SAndroid Build Coastguard Worker       : type(other.type),
150*da0073e9SAndroid Build Coastguard Worker         value(other.value),
151*da0073e9SAndroid Build Coastguard Worker         traceback(other.traceback),
152*da0073e9SAndroid Build Coastguard Worker         message(other.message) {
153*da0073e9SAndroid Build Coastguard Worker     pybind11::gil_scoped_acquire gil;
154*da0073e9SAndroid Build Coastguard Worker     Py_XINCREF(type);
155*da0073e9SAndroid Build Coastguard Worker     Py_XINCREF(value);
156*da0073e9SAndroid Build Coastguard Worker     Py_XINCREF(traceback);
157*da0073e9SAndroid Build Coastguard Worker   }
158*da0073e9SAndroid Build Coastguard Worker 
python_errorpython_error159*da0073e9SAndroid Build Coastguard Worker   python_error(python_error&& other) noexcept
160*da0073e9SAndroid Build Coastguard Worker       : type(other.type),
161*da0073e9SAndroid Build Coastguard Worker         value(other.value),
162*da0073e9SAndroid Build Coastguard Worker         traceback(other.traceback),
163*da0073e9SAndroid Build Coastguard Worker         message(std::move(other.message)) {
164*da0073e9SAndroid Build Coastguard Worker     other.type = nullptr;
165*da0073e9SAndroid Build Coastguard Worker     other.value = nullptr;
166*da0073e9SAndroid Build Coastguard Worker     other.traceback = nullptr;
167*da0073e9SAndroid Build Coastguard Worker   }
168*da0073e9SAndroid Build Coastguard Worker 
169*da0073e9SAndroid Build Coastguard Worker   python_error& operator=(const python_error& other) = delete;
170*da0073e9SAndroid Build Coastguard Worker   python_error& operator=(python_error&& other) = delete;
171*da0073e9SAndroid Build Coastguard Worker 
172*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(bugprone-exception-escape)
~python_errorpython_error173*da0073e9SAndroid Build Coastguard Worker   ~python_error() override {
174*da0073e9SAndroid Build Coastguard Worker     if (type || value || traceback) {
175*da0073e9SAndroid Build Coastguard Worker       pybind11::gil_scoped_acquire gil;
176*da0073e9SAndroid Build Coastguard Worker       Py_XDECREF(type);
177*da0073e9SAndroid Build Coastguard Worker       Py_XDECREF(value);
178*da0073e9SAndroid Build Coastguard Worker       Py_XDECREF(traceback);
179*da0073e9SAndroid Build Coastguard Worker     }
180*da0073e9SAndroid Build Coastguard Worker   }
181*da0073e9SAndroid Build Coastguard Worker 
whatpython_error182*da0073e9SAndroid Build Coastguard Worker   const char* what() const noexcept override {
183*da0073e9SAndroid Build Coastguard Worker     return message.c_str();
184*da0073e9SAndroid Build Coastguard Worker   }
185*da0073e9SAndroid Build Coastguard Worker 
build_messagepython_error186*da0073e9SAndroid Build Coastguard Worker   void build_message() {
187*da0073e9SAndroid Build Coastguard Worker     // Ensure we have the GIL.
188*da0073e9SAndroid Build Coastguard Worker     pybind11::gil_scoped_acquire gil;
189*da0073e9SAndroid Build Coastguard Worker 
190*da0073e9SAndroid Build Coastguard Worker     // No errors should be set when we enter the function since PyErr_Fetch
191*da0073e9SAndroid Build Coastguard Worker     // clears the error indicator.
192*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(!PyErr_Occurred());
193*da0073e9SAndroid Build Coastguard Worker 
194*da0073e9SAndroid Build Coastguard Worker     // Default message.
195*da0073e9SAndroid Build Coastguard Worker     message = "python_error";
196*da0073e9SAndroid Build Coastguard Worker 
197*da0073e9SAndroid Build Coastguard Worker     // Try to retrieve the error message from the value.
198*da0073e9SAndroid Build Coastguard Worker     if (value != nullptr) {
199*da0073e9SAndroid Build Coastguard Worker       // Reference count should not be zero.
200*da0073e9SAndroid Build Coastguard Worker       TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);
201*da0073e9SAndroid Build Coastguard Worker 
202*da0073e9SAndroid Build Coastguard Worker       PyObject* pyStr = PyObject_Str(value);
203*da0073e9SAndroid Build Coastguard Worker       if (pyStr != nullptr) {
204*da0073e9SAndroid Build Coastguard Worker         PyObject* encodedString =
205*da0073e9SAndroid Build Coastguard Worker             PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
206*da0073e9SAndroid Build Coastguard Worker         if (encodedString != nullptr) {
207*da0073e9SAndroid Build Coastguard Worker           char* bytes = PyBytes_AS_STRING(encodedString);
208*da0073e9SAndroid Build Coastguard Worker           if (bytes != nullptr) {
209*da0073e9SAndroid Build Coastguard Worker             // Set the message.
210*da0073e9SAndroid Build Coastguard Worker             message = std::string(bytes);
211*da0073e9SAndroid Build Coastguard Worker           }
212*da0073e9SAndroid Build Coastguard Worker           Py_XDECREF(encodedString);
213*da0073e9SAndroid Build Coastguard Worker         }
214*da0073e9SAndroid Build Coastguard Worker         Py_XDECREF(pyStr);
215*da0073e9SAndroid Build Coastguard Worker       }
216*da0073e9SAndroid Build Coastguard Worker     }
217*da0073e9SAndroid Build Coastguard Worker 
218*da0073e9SAndroid Build Coastguard Worker     // Clear any errors since we don't want to propagate errors for functions
219*da0073e9SAndroid Build Coastguard Worker     // that are trying to build a string for the error message.
220*da0073e9SAndroid Build Coastguard Worker     PyErr_Clear();
221*da0073e9SAndroid Build Coastguard Worker   }
222*da0073e9SAndroid Build Coastguard Worker 
223*da0073e9SAndroid Build Coastguard Worker   /** Saves the exception so that it can be re-thrown on a different thread */
persistpython_error224*da0073e9SAndroid Build Coastguard Worker   inline void persist() {
225*da0073e9SAndroid Build Coastguard Worker     if (type)
226*da0073e9SAndroid Build Coastguard Worker       return; // Don't overwrite exceptions
227*da0073e9SAndroid Build Coastguard Worker     // PyErr_Fetch overwrites the pointers
228*da0073e9SAndroid Build Coastguard Worker     pybind11::gil_scoped_acquire gil;
229*da0073e9SAndroid Build Coastguard Worker     Py_XDECREF(type);
230*da0073e9SAndroid Build Coastguard Worker     Py_XDECREF(value);
231*da0073e9SAndroid Build Coastguard Worker     Py_XDECREF(traceback);
232*da0073e9SAndroid Build Coastguard Worker     PyErr_Fetch(&type, &value, &traceback);
233*da0073e9SAndroid Build Coastguard Worker     build_message();
234*da0073e9SAndroid Build Coastguard Worker   }
235*da0073e9SAndroid Build Coastguard Worker 
236*da0073e9SAndroid Build Coastguard Worker   /** Sets the current Python error from this exception */
restorepython_error237*da0073e9SAndroid Build Coastguard Worker   inline void restore() {
238*da0073e9SAndroid Build Coastguard Worker     if (!type)
239*da0073e9SAndroid Build Coastguard Worker       return;
240*da0073e9SAndroid Build Coastguard Worker     // PyErr_Restore steals references
241*da0073e9SAndroid Build Coastguard Worker     pybind11::gil_scoped_acquire gil;
242*da0073e9SAndroid Build Coastguard Worker     Py_XINCREF(type);
243*da0073e9SAndroid Build Coastguard Worker     Py_XINCREF(value);
244*da0073e9SAndroid Build Coastguard Worker     Py_XINCREF(traceback);
245*da0073e9SAndroid Build Coastguard Worker     PyErr_Restore(type, value, traceback);
246*da0073e9SAndroid Build Coastguard Worker   }
247*da0073e9SAndroid Build Coastguard Worker 
248*da0073e9SAndroid Build Coastguard Worker   PyObject* type{nullptr};
249*da0073e9SAndroid Build Coastguard Worker   PyObject* value{nullptr};
250*da0073e9SAndroid Build Coastguard Worker   PyObject* traceback{nullptr};
251*da0073e9SAndroid Build Coastguard Worker 
252*da0073e9SAndroid Build Coastguard Worker   // Message to return to the user when 'what()' is invoked.
253*da0073e9SAndroid Build Coastguard Worker   std::string message;
254*da0073e9SAndroid Build Coastguard Worker };
255*da0073e9SAndroid Build Coastguard Worker 
256*da0073e9SAndroid Build Coastguard Worker bool THPException_init(PyObject* module);
257*da0073e9SAndroid Build Coastguard Worker 
258*da0073e9SAndroid Build Coastguard Worker namespace torch {
259*da0073e9SAndroid Build Coastguard Worker 
260*da0073e9SAndroid Build Coastguard Worker // Set python current exception from a C++ exception
261*da0073e9SAndroid Build Coastguard Worker TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);
262*da0073e9SAndroid Build Coastguard Worker 
263*da0073e9SAndroid Build Coastguard Worker TORCH_PYTHON_API std::string processErrorMsg(std::string str);
264*da0073e9SAndroid Build Coastguard Worker 
265*da0073e9SAndroid Build Coastguard Worker // Abstract base class for exceptions which translate to specific Python types
266*da0073e9SAndroid Build Coastguard Worker struct PyTorchError : public std::exception {
267*da0073e9SAndroid Build Coastguard Worker   PyTorchError() = default;
PyTorchErrorPyTorchError268*da0073e9SAndroid Build Coastguard Worker   PyTorchError(std::string msg_) : msg(std::move(msg_)) {}
269*da0073e9SAndroid Build Coastguard Worker   virtual PyObject* python_type() = 0;
whatPyTorchError270*da0073e9SAndroid Build Coastguard Worker   const char* what() const noexcept override {
271*da0073e9SAndroid Build Coastguard Worker     return msg.c_str();
272*da0073e9SAndroid Build Coastguard Worker   }
273*da0073e9SAndroid Build Coastguard Worker   std::string msg;
274*da0073e9SAndroid Build Coastguard Worker };
275*da0073e9SAndroid Build Coastguard Worker 
// Marks a variadic function (or constructor) as printf-like on gcc and clang
// so the compiler checks format strings against their arguments.
// FORMAT_INDEX / VA_ARGS_INDEX are 1-based. For non-static member functions
// the implicit `this` counts as argument 1. Other compilers get an empty
// expansion.
#if defined(__GNUC__)
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((__format__(__printf__, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
284*da0073e9SAndroid Build Coastguard Worker 
285*da0073e9SAndroid Build Coastguard Worker // Translates to Python TypeError
286*da0073e9SAndroid Build Coastguard Worker struct TypeError : public PyTorchError {
287*da0073e9SAndroid Build Coastguard Worker   using PyTorchError::PyTorchError;
288*da0073e9SAndroid Build Coastguard Worker   TORCH_PYTHON_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeTypeError289*da0073e9SAndroid Build Coastguard Worker   PyObject* python_type() override {
290*da0073e9SAndroid Build Coastguard Worker     return PyExc_TypeError;
291*da0073e9SAndroid Build Coastguard Worker   }
292*da0073e9SAndroid Build Coastguard Worker };
293*da0073e9SAndroid Build Coastguard Worker 
294*da0073e9SAndroid Build Coastguard Worker // Translates to Python AttributeError
295*da0073e9SAndroid Build Coastguard Worker struct AttributeError : public PyTorchError {
296*da0073e9SAndroid Build Coastguard Worker   AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
python_typeAttributeError297*da0073e9SAndroid Build Coastguard Worker   PyObject* python_type() override {
298*da0073e9SAndroid Build Coastguard Worker     return PyExc_AttributeError;
299*da0073e9SAndroid Build Coastguard Worker   }
300*da0073e9SAndroid Build Coastguard Worker };
301*da0073e9SAndroid Build Coastguard Worker 
302*da0073e9SAndroid Build Coastguard Worker // ATen warning handler for Python
303*da0073e9SAndroid Build Coastguard Worker struct PyWarningHandler {
304*da0073e9SAndroid Build Coastguard Worker   // Move actual handler into a separate class with a noexcept
305*da0073e9SAndroid Build Coastguard Worker   // destructor. Otherwise, we need to force all WarningHandler
306*da0073e9SAndroid Build Coastguard Worker   // subclasses to have a noexcept(false) destructor.
307*da0073e9SAndroid Build Coastguard Worker   struct InternalHandler : at::WarningHandler {
308*da0073e9SAndroid Build Coastguard Worker     ~InternalHandler() override = default;
309*da0073e9SAndroid Build Coastguard Worker     void process(const c10::Warning& warning) override;
310*da0073e9SAndroid Build Coastguard Worker 
311*da0073e9SAndroid Build Coastguard Worker     std::vector<c10::Warning> warning_buffer_;
312*da0073e9SAndroid Build Coastguard Worker   };
313*da0073e9SAndroid Build Coastguard Worker 
314*da0073e9SAndroid Build Coastguard Worker  public:
315*da0073e9SAndroid Build Coastguard Worker   /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
316*da0073e9SAndroid Build Coastguard Worker   TORCH_PYTHON_API PyWarningHandler() noexcept(true);
317*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(bugprone-exception-escape)
318*da0073e9SAndroid Build Coastguard Worker   TORCH_PYTHON_API ~PyWarningHandler() noexcept(false);
319*da0073e9SAndroid Build Coastguard Worker 
320*da0073e9SAndroid Build Coastguard Worker   /** Call if an exception has been thrown
321*da0073e9SAndroid Build Coastguard Worker 
322*da0073e9SAndroid Build Coastguard Worker    *  Necessary to determine if it is safe to throw from the desctructor since
323*da0073e9SAndroid Build Coastguard Worker    *  std::uncaught_exception is buggy on some platforms and generally
324*da0073e9SAndroid Build Coastguard Worker    *  unreliable across dynamic library calls.
325*da0073e9SAndroid Build Coastguard Worker    */
set_in_exceptionPyWarningHandler326*da0073e9SAndroid Build Coastguard Worker   void set_in_exception() {
327*da0073e9SAndroid Build Coastguard Worker     in_exception_ = true;
328*da0073e9SAndroid Build Coastguard Worker   }
329*da0073e9SAndroid Build Coastguard Worker 
330*da0073e9SAndroid Build Coastguard Worker  private:
331*da0073e9SAndroid Build Coastguard Worker   InternalHandler internal_handler_;
332*da0073e9SAndroid Build Coastguard Worker   at::WarningHandler* prev_handler_;
333*da0073e9SAndroid Build Coastguard Worker   bool in_exception_;
334*da0073e9SAndroid Build Coastguard Worker };
335*da0073e9SAndroid Build Coastguard Worker 
336*da0073e9SAndroid Build Coastguard Worker namespace detail {
337*da0073e9SAndroid Build Coastguard Worker 
338*da0073e9SAndroid Build Coastguard Worker struct noop_gil_scoped_release {
339*da0073e9SAndroid Build Coastguard Worker   // user-defined constructor (i.e. not defaulted) to avoid
340*da0073e9SAndroid Build Coastguard Worker   // unused-variable warnings at usage sites of this class
noop_gil_scoped_releasenoop_gil_scoped_release341*da0073e9SAndroid Build Coastguard Worker   noop_gil_scoped_release() {}
342*da0073e9SAndroid Build Coastguard Worker };
343*da0073e9SAndroid Build Coastguard Worker 
344*da0073e9SAndroid Build Coastguard Worker template <bool release_gil>
345*da0073e9SAndroid Build Coastguard Worker using conditional_gil_scoped_release = std::conditional_t<
346*da0073e9SAndroid Build Coastguard Worker     release_gil,
347*da0073e9SAndroid Build Coastguard Worker     pybind11::gil_scoped_release,
348*da0073e9SAndroid Build Coastguard Worker     noop_gil_scoped_release>;
349*da0073e9SAndroid Build Coastguard Worker 
350*da0073e9SAndroid Build Coastguard Worker template <typename Func, size_t i>
351*da0073e9SAndroid Build Coastguard Worker using Arg = typename invoke_traits<Func>::template arg<i>::type;
352*da0073e9SAndroid Build Coastguard Worker 
353*da0073e9SAndroid Build Coastguard Worker template <typename Func, size_t... Is, bool release_gil>
wrap_pybind_function_impl_(Func && f,std::index_sequence<Is...>,std::bool_constant<release_gil>)354*da0073e9SAndroid Build Coastguard Worker auto wrap_pybind_function_impl_(
355*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
356*da0073e9SAndroid Build Coastguard Worker     Func&& f,
357*da0073e9SAndroid Build Coastguard Worker     std::index_sequence<Is...>,
358*da0073e9SAndroid Build Coastguard Worker     std::bool_constant<release_gil>) {
359*da0073e9SAndroid Build Coastguard Worker   namespace py = pybind11;
360*da0073e9SAndroid Build Coastguard Worker 
361*da0073e9SAndroid Build Coastguard Worker   // f=f is needed to handle function references on older compilers
362*da0073e9SAndroid Build Coastguard Worker   return [f = std::forward<Func>(f)](Arg<Func, Is>... args) {
363*da0073e9SAndroid Build Coastguard Worker     HANDLE_TH_ERRORS
364*da0073e9SAndroid Build Coastguard Worker     conditional_gil_scoped_release<release_gil> no_gil;
365*da0073e9SAndroid Build Coastguard Worker     return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
366*da0073e9SAndroid Build Coastguard Worker     END_HANDLE_TH_ERRORS_PYBIND
367*da0073e9SAndroid Build Coastguard Worker   };
368*da0073e9SAndroid Build Coastguard Worker }
369*da0073e9SAndroid Build Coastguard Worker } // namespace detail
370*da0073e9SAndroid Build Coastguard Worker 
371*da0073e9SAndroid Build Coastguard Worker // Wrap a function with TH error and warning handling.
372*da0073e9SAndroid Build Coastguard Worker // Returns a function object suitable for registering with pybind11.
373*da0073e9SAndroid Build Coastguard Worker template <typename Func>
wrap_pybind_function(Func && f)374*da0073e9SAndroid Build Coastguard Worker auto wrap_pybind_function(Func&& f) {
375*da0073e9SAndroid Build Coastguard Worker   using traits = invoke_traits<Func>;
376*da0073e9SAndroid Build Coastguard Worker   return torch::detail::wrap_pybind_function_impl_(
377*da0073e9SAndroid Build Coastguard Worker       std::forward<Func>(f),
378*da0073e9SAndroid Build Coastguard Worker       std::make_index_sequence<traits::arity>{},
379*da0073e9SAndroid Build Coastguard Worker       std::false_type{});
380*da0073e9SAndroid Build Coastguard Worker }
381*da0073e9SAndroid Build Coastguard Worker 
382*da0073e9SAndroid Build Coastguard Worker // Wrap a function with TH error, warning handling and releases the GIL.
383*da0073e9SAndroid Build Coastguard Worker // Returns a function object suitable for registering with pybind11.
384*da0073e9SAndroid Build Coastguard Worker template <typename Func>
wrap_pybind_function_no_gil(Func && f)385*da0073e9SAndroid Build Coastguard Worker auto wrap_pybind_function_no_gil(Func&& f) {
386*da0073e9SAndroid Build Coastguard Worker   using traits = invoke_traits<Func>;
387*da0073e9SAndroid Build Coastguard Worker   return torch::detail::wrap_pybind_function_impl_(
388*da0073e9SAndroid Build Coastguard Worker       std::forward<Func>(f),
389*da0073e9SAndroid Build Coastguard Worker       std::make_index_sequence<traits::arity>{},
390*da0073e9SAndroid Build Coastguard Worker       std::true_type{});
391*da0073e9SAndroid Build Coastguard Worker }
392*da0073e9SAndroid Build Coastguard Worker 
393*da0073e9SAndroid Build Coastguard Worker } // namespace torch
394