1 #include <torch/csrc/Exceptions.h>
2 #include <torch/csrc/python_headers.h>
3
4 #include <array>
5 #include <cstdarg>
6 #include <exception>
7 #include <utility>
8
9 #include <fmt/format.h>
10 #include <torch/csrc/THP.h>
11
12 #include <c10/util/StringUtil.h>
13
14 PyObject *THPException_FatalError, *THPException_LinAlgError,
15 *THPException_OutOfMemoryError, *THPException_DistError,
16 *THPException_DistBackendError, *THPException_DistNetworkError,
17 *THPException_DistStoreError;
18
// Bails out of the enclosing bool-returning function with `false` when `cond`
// evaluates to false. The do { ... } while (0) wrapper turns each use into a
// single statement that requires a trailing semicolon, so the macro cannot
// capture a following `else` (dangling-else) when used under an unbraced `if`.
#define ASSERT_TRUE(cond) \
  do {                    \
    if (!(cond)) {        \
      return false;       \
    }                     \
  } while (0)
THPException_init(PyObject * module)22 bool THPException_init(PyObject* module) {
23 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
24 ASSERT_TRUE(
25 THPException_FatalError =
26 PyErr_NewException("torch.FatalError", nullptr, nullptr));
27 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
28 ASSERT_TRUE(
29 PyModule_AddObject(module, "FatalError", THPException_FatalError) == 0);
30
31 // Set the doc string here since _add_docstr throws malloc errors if tp_doc is
32 // modified for an error class.
33 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
34 ASSERT_TRUE(
35 THPException_LinAlgError = PyErr_NewExceptionWithDoc(
36 "torch._C._LinAlgError",
37 "Error raised by torch.linalg function when the cause of error is a numerical inconsistency in the data.\n \
38 For example, you can the torch.linalg.inv function will raise torch.linalg.LinAlgError when it finds that \
39 a matrix is not invertible.\n \
40 \n\
41 Example:\n \
42 >>> # xdoctest: +REQUIRES(env:TORCH_DOCKTEST_LAPACK)\n \
43 >>> matrix = torch.eye(3, 3)\n \
44 >>> matrix[-1, -1] = 0\n \
45 >>> matrix\n \
46 tensor([[1., 0., 0.],\n \
47 [0., 1., 0.],\n \
48 [0., 0., 0.]])\n \
49 >>> torch.linalg.inv(matrix)\n \
50 Traceback (most recent call last):\n \
51 File \"<stdin>\", line 1, in <module>\n \
52 torch._C._LinAlgError: torch.linalg.inv: The diagonal element 3 is zero, the inversion\n \
53 could not be completed because the input matrix is singular.",
54 PyExc_RuntimeError,
55 nullptr));
56 ASSERT_TRUE(
57 PyModule_AddObject(module, "_LinAlgError", THPException_LinAlgError) ==
58 0);
59
60 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
61 ASSERT_TRUE(
62 THPException_OutOfMemoryError = PyErr_NewExceptionWithDoc(
63 "torch.OutOfMemoryError",
64 "Exception raised when device is out of memory",
65 PyExc_RuntimeError,
66 nullptr));
67 PyTypeObject* type = (PyTypeObject*)THPException_OutOfMemoryError;
68 type->tp_name = "torch.OutOfMemoryError";
69 ASSERT_TRUE(
70 PyModule_AddObject(
71 module, "OutOfMemoryError", THPException_OutOfMemoryError) == 0);
72
73 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
74 ASSERT_TRUE(
75 THPException_DistError = PyErr_NewExceptionWithDoc(
76 "torch.distributed.DistError",
77 "Exception raised when an error occurs in the distributed library",
78 PyExc_RuntimeError,
79 nullptr));
80 ASSERT_TRUE(
81 PyModule_AddObject(module, "_DistError", THPException_DistError) == 0);
82
83 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
84 ASSERT_TRUE(
85 THPException_DistBackendError = PyErr_NewExceptionWithDoc(
86 "torch.distributed.DistBackendError",
87 "Exception raised when a backend error occurs in distributed",
88 THPException_DistError,
89 nullptr));
90 ASSERT_TRUE(
91 PyModule_AddObject(
92 module, "_DistBackendError", THPException_DistBackendError) == 0);
93
94 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
95 ASSERT_TRUE(
96 THPException_DistNetworkError = PyErr_NewExceptionWithDoc(
97 "torch.distributed.DistNetworkError",
98 "Exception raised when a network error occurs in distributed",
99 THPException_DistError,
100 nullptr));
101 ASSERT_TRUE(
102 PyModule_AddObject(
103 module, "_DistNetworkError", THPException_DistNetworkError) == 0);
104
105 // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
106 ASSERT_TRUE(
107 THPException_DistStoreError = PyErr_NewExceptionWithDoc(
108 "torch.distributed.DistStoreError",
109 "Exception raised when an error occurs in the distributed store",
110 THPException_DistError,
111 nullptr));
112 ASSERT_TRUE(
113 PyModule_AddObject(
114 module, "_DistStoreError", THPException_DistStoreError) == 0);
115
116 return true;
117 }
118
119 namespace torch {
120
processErrorMsgInplace(std::string & str)121 void processErrorMsgInplace(std::string& str) {
122 // Translate Aten types to their respective pytorch ones
123 constexpr std::array<std::pair<c10::string_view, c10::string_view>, 64>
124 changes{{
125 // TODO: remove torch.(cuda.|)sparse.*Tensor items?
126 {"Variable[SparseCUDAByteType]", "torch.cuda.sparse.ByteTensor"},
127 {"Variable[SparseCUDACharType]", "torch.cuda.sparse.CharTensor"},
128 {"Variable[SparseCUDADoubleType]", "torch.cuda.sparse.DoubleTensor"},
129 {"Variable[SparseCUDAFloatType]", "torch.cuda.sparse.FloatTensor"},
130 {"Variable[SparseCUDAIntType]", "torch.cuda.sparse.IntTensor"},
131 {"Variable[SparseCUDALongType]", "torch.cuda.sparse.LongTensor"},
132 {"Variable[SparseCUDAShortType]", "torch.cuda.sparse.ShortTensor"},
133 {"Variable[SparseCUDAHalfType]", "torch.cuda.sparse.HalfTensor"},
134 {"Variable[SparseCPUByteType]", "torch.sparse.ByteTensor"},
135 {"Variable[SparseCPUCharType]", "torch.sparse.CharTensor"},
136 {"Variable[SparseCPUDoubleType]", "torch.sparse.DoubleTensor"},
137 {"Variable[SparseCPUFloatType]", "torch.sparse.FloatTensor"},
138 {"Variable[SparseCPUIntType]", "torch.sparse.IntTensor"},
139 {"Variable[SparseCPULongType]", "torch.sparse.LongTensor"},
140 {"Variable[SparseCPUShortType]", "torch.sparse.ShortTensor"},
141 {"Variable[SparseCPUHalfType]", "torch.sparse.HalfTensor"},
142 {"Variable[CUDAByteType]", "torch.cuda.ByteTensor"},
143 {"Variable[CUDACharType]", "torch.cuda.CharTensor"},
144 {"Variable[CUDADoubleType]", "torch.cuda.DoubleTensor"},
145 {"Variable[CUDAFloatType]", "torch.cuda.FloatTensor"},
146 {"Variable[CUDAIntType]", "torch.cuda.IntTensor"},
147 {"Variable[CUDALongType]", "torch.cuda.LongTensor"},
148 {"Variable[CUDAShortType]", "torch.cuda.ShortTensor"},
149 {"Variable[CUDAHalfType]", "torch.cuda.HalfTensor"},
150 {"Variable[CPUByteType]", "torch.ByteTensor"},
151 {"Variable[CPUCharType]", "torch.CharTensor"},
152 {"Variable[CPUDoubleType]", "torch.DoubleTensor"},
153 {"Variable[CPUFloatType]", "torch.FloatTensor"},
154 {"Variable[CPUIntType]", "torch.IntTensor"},
155 {"Variable[CPULongType]", "torch.LongTensor"},
156 {"Variable[CPUShortType]", "torch.ShortTensor"},
157 {"Variable[CPUHalfType]", "torch.HalfTensor"},
158 {"SparseCUDAByteType", "torch.cuda.sparse.ByteTensor"},
159 {"SparseCUDACharType", "torch.cuda.sparse.CharTensor"},
160 {"SparseCUDADoubleType", "torch.cuda.sparse.DoubleTensor"},
161 {"SparseCUDAFloatType", "torch.cuda.sparse.FloatTensor"},
162 {"SparseCUDAIntType", "torch.cuda.sparse.IntTensor"},
163 {"SparseCUDALongType", "torch.cuda.sparse.LongTensor"},
164 {"SparseCUDAShortType", "torch.cuda.sparse.ShortTensor"},
165 {"SparseCUDAHalfType", "torch.cuda.sparse.HalfTensor"},
166 {"SparseCPUByteType", "torch.sparse.ByteTensor"},
167 {"SparseCPUCharType", "torch.sparse.CharTensor"},
168 {"SparseCPUDoubleType", "torch.sparse.DoubleTensor"},
169 {"SparseCPUFloatType", "torch.sparse.FloatTensor"},
170 {"SparseCPUIntType", "torch.sparse.IntTensor"},
171 {"SparseCPULongType", "torch.sparse.LongTensor"},
172 {"SparseCPUShortType", "torch.sparse.ShortTensor"},
173 {"SparseCPUHalfType", "torch.sparse.HalfTensor"},
174 {"CUDAByteType", "torch.cuda.ByteTensor"},
175 {"CUDACharType", "torch.cuda.CharTensor"},
176 {"CUDADoubleType", "torch.cuda.DoubleTensor"},
177 {"CUDAFloatType", "torch.cuda.FloatTensor"},
178 {"CUDAIntType", "torch.cuda.IntTensor"},
179 {"CUDALongType", "torch.cuda.LongTensor"},
180 {"CUDAShortType", "torch.cuda.ShortTensor"},
181 {"CUDAHalfType", "torch.cuda.HalfTensor"},
182 {"CPUByteType", "torch.ByteTensor"},
183 {"CPUCharType", "torch.CharTensor"},
184 {"CPUDoubleType", "torch.DoubleTensor"},
185 {"CPUFloatType", "torch.FloatTensor"},
186 {"CPUIntType", "torch.IntTensor"},
187 {"CPULongType", "torch.LongTensor"},
188 {"CPUShortType", "torch.ShortTensor"},
189 {"CPUHalfType", "torch.HalfTensor"},
190 }};
191
192 // Avoid doing any work if no types need translated
193 if (str.find("Type") == str.npos) {
194 return;
195 }
196 for (const auto& it : changes) {
197 c10::ReplaceAll(str, it.first, it.second);
198 }
199 }
200
processErrorMsg(std::string str)201 std::string processErrorMsg(std::string str) {
202 processErrorMsgInplace(str);
203 return str;
204 }
205
formatMessage(const char * format,va_list fmt_args)206 static std::string formatMessage(const char* format, va_list fmt_args) {
207 static const size_t ERROR_BUF_SIZE = 1024;
208 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
209 char error_buf[ERROR_BUF_SIZE];
210 vsnprintf(error_buf, ERROR_BUF_SIZE, format, fmt_args);
211
212 // Ensure that the string is null terminated
213 error_buf[sizeof(error_buf) / sizeof(*error_buf) - 1] = 0;
214
215 return std::string(error_buf);
216 }
217
translate_exception_to_python(const std::exception_ptr & e_ptr)218 void translate_exception_to_python(const std::exception_ptr& e_ptr) {
219 try {
220 TORCH_INTERNAL_ASSERT(
221 e_ptr,
222 "translate_exception_to_python "
223 "called with invalid exception pointer");
224 std::rethrow_exception(e_ptr);
225 }
226 CATCH_ALL_ERRORS(return)
227 }
228
TypeError(const char * format,...)229 TypeError::TypeError(const char* format, ...) {
230 va_list fmt_args{};
231 va_start(fmt_args, format);
232 msg = formatMessage(format, fmt_args);
233 va_end(fmt_args);
234 }
235
AttributeError(const char * format,...)236 AttributeError::AttributeError(const char* format, ...) {
237 va_list fmt_args{};
238 va_start(fmt_args, format);
239 msg = formatMessage(format, fmt_args);
240 va_end(fmt_args);
241 }
242
process(const c10::Warning & warning)243 void PyWarningHandler::InternalHandler::process(const c10::Warning& warning) {
244 warning_buffer_.push_back(warning);
245 }
246
PyWarningHandler()247 PyWarningHandler::PyWarningHandler() noexcept(true)
248 : prev_handler_(c10::WarningUtils::get_warning_handler()),
249 in_exception_(false) {
250 c10::WarningUtils::set_warning_handler(&internal_handler_);
251 }
252
253 // Get the Python warning type for a warning
map_warning_to_python_type(const c10::Warning & warning)254 PyObject* map_warning_to_python_type(const c10::Warning& warning) {
255 struct Visitor {
256 PyObject* operator()(const c10::UserWarning&) const {
257 return PyExc_UserWarning;
258 }
259 PyObject* operator()(const c10::DeprecationWarning&) const {
260 return PyExc_DeprecationWarning;
261 }
262 };
263 return std::visit(Visitor(), warning.type());
264 }
265
/// Restores the c10 warning handler that was active when this
/// PyWarningHandler was constructed, then re-emits every buffered c10 warning
/// as a Python warning.
///
/// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
///
/// - in_exception_ false: warnings are emitted in order. The first one whose
///   Python call fails (result < 0) stops the loop, the remaining buffered
///   warnings are discarded, and python_error is thrown.
/// - in_exception_ true: any pending Python error is fetched first. Each
///   failing warning is printed and cleared with PyErr_Print, and the fetched
///   error is restored at the end, so nothing is thrown in this case.
///   (in_exception_ is set outside this section, presumably when the guarded
///   scope is left via an exception; TODO confirm.)
/// NOLINTNEXTLINE(bugprone-exception-escape)
PyWarningHandler::~PyWarningHandler() noexcept(false) {
  // Undo the redirection installed by the constructor.
  c10::WarningUtils::set_warning_handler(prev_handler_);
  auto& warning_buffer = internal_handler_.warning_buffer_;

  if (!warning_buffer.empty()) {
    PyObject *type = nullptr, *value = nullptr, *traceback = nullptr;
    // The Python C-API warning/error calls below require the GIL.
    pybind11::gil_scoped_acquire gil;
    // Status of the most recent warning call; negative means it failed.
    auto result = 0;
    if (in_exception_) {
      // This (combined with PyErr_Restore below) also works when no python
      // error has been set yet
      PyErr_Fetch(&type, &value, &traceback);
    }
    for (const auto& warning : warning_buffer) {
      auto source_location = warning.source_location();
      auto msg = warning.msg();
      // Rewrite ATen type names into their PyTorch equivalents.
      processErrorMsgInplace(msg);
      if (source_location.file == nullptr) {
        // No C++ file recorded: emit the message as-is with stacklevel 1.
        result =
            PyErr_WarnEx(map_warning_to_python_type(warning), msg.c_str(), 1);
      } else if (warning.verbatim()) {
        // Sets the source location from the warning
        // Note: PyErr_WarnExplicit will disregard Python's warning filter
        // and always appear. This is in contrast to PyErr_WarnEx,
        // which respects the warning filter.
        result = PyErr_WarnExplicit(
            /*category=*/map_warning_to_python_type(warning),
            /*message=*/msg.c_str(),
            /*filename=*/source_location.file,
            /*lineno=*/static_cast<int>(source_location.line),
            /*module=*/nullptr,
            /*registry=*/nullptr);
      } else {
        // Lets Python set the source location and puts the C++ warning
        // location into the message.
        auto buf = fmt::format(
            "{} (Triggered internally at {}:{}.)",
            msg,
            source_location.file,
            source_location.line);
        result =
            PyErr_WarnEx(map_warning_to_python_type(warning), buf.c_str(), 1);
      }
      if (result < 0) {
        if (in_exception_) {
          // PyErr_Print prints the traceback to sys.stderr and
          // clears the error indicator
          PyErr_Print();
        } else {
          // Stop at the first failure; the remaining warnings are dropped by
          // the clear() below.
          break;
        }
      }
    }
    warning_buffer.clear();
    if ((result < 0) && (!in_exception_)) {
      /// A warning raised an error, we need to force the parent
      /// function to return an error code.
      throw python_error();
    }
    if (in_exception_) {
      // Put back the Python error fetched before emitting the warnings.
      PyErr_Restore(type, value, traceback);
    }
  }
}
332
333 } // namespace torch
334