#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>

#include <array>
#include <cstdarg>
#include <exception>
#include <utility>

#include <fmt/format.h>
#include <torch/csrc/THP.h>

#include <c10/util/StringUtil.h>

PyObject *THPException_FatalError, *THPException_LinAlgError,
    *THPException_OutOfMemoryError, *THPException_DistError,
    *THPException_DistBackendError, *THPException_DistNetworkError,
    *THPException_DistStoreError;

#define ASSERT_TRUE(cond) \
  if (!(cond))            \
  return false
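
// ASSERT_TRUE makes the enclosing function return false when `cond` is
// falsy. It is intentionally used below with assignments inside the `if`
// condition, which is why the NOLINT(bugprone-assignment-in-if-condition)
// annotations accompany each use.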
bool THPException_init(PyObject* module) {
  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      THPException_FatalError =
          PyErr_NewException("torch.FatalError", nullptr, nullptr));
  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      PyModule_AddObject(module, "FatalError", THPException_FatalError) == 0);

  // Set the doc string here since _add_docstr throws malloc errors if tp_doc is
  // modified for an error class.
  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      THPException_LinAlgError = PyErr_NewExceptionWithDoc(
          "torch._C._LinAlgError",
37           "Error raised by torch.linalg function when the cause of error is a numerical inconsistency in the data.\n \
38 For example, you can the torch.linalg.inv function will raise torch.linalg.LinAlgError when it finds that \
39 a matrix is not invertible.\n \
40 \n\
41 Example:\n \
42 >>> # xdoctest: +REQUIRES(env:TORCH_DOCKTEST_LAPACK)\n \
>>> matrix = torch.eye(3, 3)\n \
>>> matrix[-1, -1] = 0\n \
>>> matrix\n \
    tensor([[1., 0., 0.],\n \
            [0., 1., 0.],\n \
            [0., 0., 0.]])\n \
>>> torch.linalg.inv(matrix)\n \
Traceback (most recent call last):\n \
File \"<stdin>\", line 1, in <module>\n \
torch._C._LinAlgError: torch.linalg.inv: The diagonal element 3 is zero, the inversion\n \
could not be completed because the input matrix is singular.",
          PyExc_RuntimeError,
          nullptr));
  ASSERT_TRUE(
      PyModule_AddObject(module, "_LinAlgError", THPException_LinAlgError) ==
      0);

  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      THPException_OutOfMemoryError = PyErr_NewExceptionWithDoc(
          "torch.OutOfMemoryError",
64           "Exception raised when device is out of memory",
          PyExc_RuntimeError,
          nullptr));
  PyTypeObject* type = (PyTypeObject*)THPException_OutOfMemoryError;
  type->tp_name = "torch.OutOfMemoryError";
  ASSERT_TRUE(
      PyModule_AddObject(
          module, "OutOfMemoryError", THPException_OutOfMemoryError) == 0);

  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      THPException_DistError = PyErr_NewExceptionWithDoc(
          "torch.distributed.DistError",
          "Exception raised when an error occurs in the distributed library",
          PyExc_RuntimeError,
          nullptr));
  ASSERT_TRUE(
      PyModule_AddObject(module, "_DistError", THPException_DistError) == 0);

  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      THPException_DistBackendError = PyErr_NewExceptionWithDoc(
          "torch.distributed.DistBackendError",
          "Exception raised when a backend error occurs in distributed",
          THPException_DistError,
          nullptr));
  ASSERT_TRUE(
      PyModule_AddObject(
          module, "_DistBackendError", THPException_DistBackendError) == 0);

  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      THPException_DistNetworkError = PyErr_NewExceptionWithDoc(
          "torch.distributed.DistNetworkError",
          "Exception raised when a network error occurs in distributed",
          THPException_DistError,
          nullptr));
  ASSERT_TRUE(
      PyModule_AddObject(
          module, "_DistNetworkError", THPException_DistNetworkError) == 0);

  // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
  ASSERT_TRUE(
      THPException_DistStoreError = PyErr_NewExceptionWithDoc(
          "torch.distributed.DistStoreError",
          "Exception raised when an error occurs in the distributed store",
          THPException_DistError,
          nullptr));
  ASSERT_TRUE(
      PyModule_AddObject(
          module, "_DistStoreError", THPException_DistStoreError) == 0);

  return true;
}
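
// Usage sketch (hypothetical caller, not part of this file): module init
// code is expected to abort import when registration fails, e.g.
//
//   // inside torch._C module initialization:
//   // if (!THPException_init(module)) return nullptr;
//
// On the Python side several of these surface as public classes, e.g.
// torch._C._LinAlgError is exposed as torch.linalg.LinAlgError.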

namespace torch {

void processErrorMsgInplace(std::string& str) {
  // Translate ATen types to their respective PyTorch ones
  constexpr std::array<std::pair<c10::string_view, c10::string_view>, 64>
      changes{{
          // TODO: remove torch.(cuda.|)sparse.*Tensor items?
          {"Variable[SparseCUDAByteType]", "torch.cuda.sparse.ByteTensor"},
          {"Variable[SparseCUDACharType]", "torch.cuda.sparse.CharTensor"},
          {"Variable[SparseCUDADoubleType]", "torch.cuda.sparse.DoubleTensor"},
          {"Variable[SparseCUDAFloatType]", "torch.cuda.sparse.FloatTensor"},
          {"Variable[SparseCUDAIntType]", "torch.cuda.sparse.IntTensor"},
          {"Variable[SparseCUDALongType]", "torch.cuda.sparse.LongTensor"},
          {"Variable[SparseCUDAShortType]", "torch.cuda.sparse.ShortTensor"},
          {"Variable[SparseCUDAHalfType]", "torch.cuda.sparse.HalfTensor"},
          {"Variable[SparseCPUByteType]", "torch.sparse.ByteTensor"},
          {"Variable[SparseCPUCharType]", "torch.sparse.CharTensor"},
          {"Variable[SparseCPUDoubleType]", "torch.sparse.DoubleTensor"},
          {"Variable[SparseCPUFloatType]", "torch.sparse.FloatTensor"},
          {"Variable[SparseCPUIntType]", "torch.sparse.IntTensor"},
          {"Variable[SparseCPULongType]", "torch.sparse.LongTensor"},
          {"Variable[SparseCPUShortType]", "torch.sparse.ShortTensor"},
          {"Variable[SparseCPUHalfType]", "torch.sparse.HalfTensor"},
          {"Variable[CUDAByteType]", "torch.cuda.ByteTensor"},
          {"Variable[CUDACharType]", "torch.cuda.CharTensor"},
          {"Variable[CUDADoubleType]", "torch.cuda.DoubleTensor"},
          {"Variable[CUDAFloatType]", "torch.cuda.FloatTensor"},
          {"Variable[CUDAIntType]", "torch.cuda.IntTensor"},
          {"Variable[CUDALongType]", "torch.cuda.LongTensor"},
          {"Variable[CUDAShortType]", "torch.cuda.ShortTensor"},
          {"Variable[CUDAHalfType]", "torch.cuda.HalfTensor"},
          {"Variable[CPUByteType]", "torch.ByteTensor"},
          {"Variable[CPUCharType]", "torch.CharTensor"},
          {"Variable[CPUDoubleType]", "torch.DoubleTensor"},
          {"Variable[CPUFloatType]", "torch.FloatTensor"},
          {"Variable[CPUIntType]", "torch.IntTensor"},
          {"Variable[CPULongType]", "torch.LongTensor"},
          {"Variable[CPUShortType]", "torch.ShortTensor"},
          {"Variable[CPUHalfType]", "torch.HalfTensor"},
          {"SparseCUDAByteType", "torch.cuda.sparse.ByteTensor"},
          {"SparseCUDACharType", "torch.cuda.sparse.CharTensor"},
          {"SparseCUDADoubleType", "torch.cuda.sparse.DoubleTensor"},
          {"SparseCUDAFloatType", "torch.cuda.sparse.FloatTensor"},
          {"SparseCUDAIntType", "torch.cuda.sparse.IntTensor"},
          {"SparseCUDALongType", "torch.cuda.sparse.LongTensor"},
          {"SparseCUDAShortType", "torch.cuda.sparse.ShortTensor"},
          {"SparseCUDAHalfType", "torch.cuda.sparse.HalfTensor"},
          {"SparseCPUByteType", "torch.sparse.ByteTensor"},
          {"SparseCPUCharType", "torch.sparse.CharTensor"},
          {"SparseCPUDoubleType", "torch.sparse.DoubleTensor"},
          {"SparseCPUFloatType", "torch.sparse.FloatTensor"},
          {"SparseCPUIntType", "torch.sparse.IntTensor"},
          {"SparseCPULongType", "torch.sparse.LongTensor"},
          {"SparseCPUShortType", "torch.sparse.ShortTensor"},
          {"SparseCPUHalfType", "torch.sparse.HalfTensor"},
          {"CUDAByteType", "torch.cuda.ByteTensor"},
          {"CUDACharType", "torch.cuda.CharTensor"},
          {"CUDADoubleType", "torch.cuda.DoubleTensor"},
          {"CUDAFloatType", "torch.cuda.FloatTensor"},
          {"CUDAIntType", "torch.cuda.IntTensor"},
          {"CUDALongType", "torch.cuda.LongTensor"},
          {"CUDAShortType", "torch.cuda.ShortTensor"},
          {"CUDAHalfType", "torch.cuda.HalfTensor"},
          {"CPUByteType", "torch.ByteTensor"},
          {"CPUCharType", "torch.CharTensor"},
          {"CPUDoubleType", "torch.DoubleTensor"},
          {"CPUFloatType", "torch.FloatTensor"},
          {"CPUIntType", "torch.IntTensor"},
          {"CPULongType", "torch.LongTensor"},
          {"CPUShortType", "torch.ShortTensor"},
          {"CPUHalfType", "torch.HalfTensor"},
      }};

  // Avoid doing any work if no types need to be translated
  if (str.find("Type") == str.npos) {
    return;
  }
  for (const auto& it : changes) {
    c10::ReplaceAll(str, it.first, it.second);
  }
}

std::string processErrorMsg(std::string str) {
  processErrorMsgInplace(str);
  return str;
}

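// Usage sketch (illustrative, not part of this file): the table above
// rewrites ATen type names embedded in error text:
//
//   std::string msg = "expected Variable[CPUFloatType] but got CUDALongType";
//   msg = torch::processErrorMsg(std::move(msg));
//   // msg == "expected torch.FloatTensor but got torch.cuda.LongTensor"
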
static std::string formatMessage(const char* format, va_list fmt_args) {
  static const size_t ERROR_BUF_SIZE = 1024;
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  char error_buf[ERROR_BUF_SIZE];
  vsnprintf(error_buf, ERROR_BUF_SIZE, format, fmt_args);
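  // Messages longer than ERROR_BUF_SIZE - 1 characters are silently
  // truncated; vsnprintf never writes more than ERROR_BUF_SIZE bytes.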

  // Ensure that the string is null terminated
  error_buf[sizeof(error_buf) / sizeof(*error_buf) - 1] = 0;

  return std::string(error_buf);
}

void translate_exception_to_python(const std::exception_ptr& e_ptr) {
  try {
    TORCH_INTERNAL_ASSERT(
        e_ptr,
        "translate_exception_to_python "
        "called with invalid exception pointer");
    std::rethrow_exception(e_ptr);
  }
  CATCH_ALL_ERRORS(return)
}

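// Usage sketch (hypothetical caller, not part of this file): binding code
// can funnel any in-flight C++ exception through the translator and return
// with the Python error indicator set:
//
//   try {
//     risky_cpp_call();  // hypothetical
//   } catch (...) {
//     torch::translate_exception_to_python(std::current_exception());
//     return nullptr;
//   }
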
TypeError::TypeError(const char* format, ...) {
  va_list fmt_args{};
  va_start(fmt_args, format);
  msg = formatMessage(format, fmt_args);
  va_end(fmt_args);
}

AttributeError::AttributeError(const char* format, ...) {
  va_list fmt_args{};
  va_start(fmt_args, format);
  msg = formatMessage(format, fmt_args);
  va_end(fmt_args);
}

void PyWarningHandler::InternalHandler::process(const c10::Warning& warning) {
  warning_buffer_.push_back(warning);
}

PyWarningHandler::PyWarningHandler() noexcept(true)
    : prev_handler_(c10::WarningUtils::get_warning_handler()),
      in_exception_(false) {
  c10::WarningUtils::set_warning_handler(&internal_handler_);
}

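// Usage sketch (hypothetical, not part of this file): the handler is meant
// to be stack-allocated around code that may emit c10 warnings; the
// destructor below flushes buffered warnings to Python on scope exit:
//
//   {
//     torch::PyWarningHandler handler;
//     run_op_that_may_warn();  // hypothetical op
//   }  // buffered c10 warnings become Python warnings here
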
// Get the Python warning type for a warning
PyObject* map_warning_to_python_type(const c10::Warning& warning) {
  struct Visitor {
    PyObject* operator()(const c10::UserWarning&) const {
      return PyExc_UserWarning;
    }
    PyObject* operator()(const c10::DeprecationWarning&) const {
      return PyExc_DeprecationWarning;
    }
  };
  return std::visit(Visitor(), warning.type());
}

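// Because std::visit requires the visitor to cover every alternative of the
// variant returned by warning.type(), adding a new c10 warning kind without
// extending Visitor above is a compile-time error.
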
/// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
/// NOLINTNEXTLINE(bugprone-exception-escape)
PyWarningHandler::~PyWarningHandler() noexcept(false) {
  c10::WarningUtils::set_warning_handler(prev_handler_);
  auto& warning_buffer = internal_handler_.warning_buffer_;

  if (!warning_buffer.empty()) {
    PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
    pybind11::gil_scoped_acquire gil;
    auto result = 0;
    if (in_exception_) {
      // This (combined with PyErr_Restore below) also works when no python
      // error has been set yet
      PyErr_Fetch(&type, &value, &traceback);
    }
    for (const auto& warning : warning_buffer) {
      auto source_location = warning.source_location();
      auto msg = warning.msg();
      processErrorMsgInplace(msg);
      if (source_location.file == nullptr) {
        result =
            PyErr_WarnEx(map_warning_to_python_type(warning), msg.c_str(), 1);
      } else if (warning.verbatim()) {
        // Sets the source location from the warning
        // Note: PyErr_WarnExplicit disregards Python's warning filter, so
        // the warning always appears. This is in contrast to PyErr_WarnEx,
        // which respects the warning filter.
        result = PyErr_WarnExplicit(
            /*category=*/map_warning_to_python_type(warning),
            /*message=*/msg.c_str(),
            /*filename=*/source_location.file,
            /*lineno=*/static_cast<int>(source_location.line),
            /*module=*/nullptr,
            /*registry=*/nullptr);
      } else {
        // Lets Python set the source location and puts the C++ warning
        // location into the message.
        auto buf = fmt::format(
            "{} (Triggered internally at {}:{}.)",
            msg,
            source_location.file,
            source_location.line);
        result =
            PyErr_WarnEx(map_warning_to_python_type(warning), buf.c_str(), 1);
      }
      if (result < 0) {
        if (in_exception_) {
          // PyErr_Print prints the traceback to sys.stderr and
          // clears the error indicator
          PyErr_Print();
        } else {
          break;
        }
      }
    }
    warning_buffer.clear();
    if ((result < 0) && (!in_exception_)) {
      /// A warning raised an error; we need to force the parent
      /// function to return an error code.
      throw python_error();
    }
    if (in_exception_) {
      PyErr_Restore(type, value, traceback);
    }
  }
}

} // namespace torch