#include // Together with `torch/utils/data/_utils/signal_handling.py`, the following // is an effort to do our best to provide some error message to users when a // worker dies due to error / critical signals. // // See NOTE [ Signal handling in multiprocessing data loading ] for more // details. // TODO: The following don't work on Windows. Specifically, sigaction, waitid // calls, and SIGCHLD handler. Currently, dummy implementations are provided // for Windows. #ifndef _WIN32 #include #include #include #include #include #include #include #include #include using namespace torch; // Critical signal handlers should be registered on worker processes before // doing work. // The handler will raise default handler so that the kill information will be // retrieved from main process. // Python handle is _set_worker_signal_handlers(). #define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG) \ static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) { \ auto _w = \ write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) / sizeof(char)); \ (void)_w; \ struct sigaction sa {}; \ sa.sa_handler = SIG_DFL; \ sa.sa_flags = 0; \ if (sigemptyset(&sa.sa_mask) != 0 || \ sigaction(SIGNAL, &sa, nullptr) != 0) { \ _exit(EXIT_FAILURE); \ } else { \ raise(SIGNAL); \ } \ } // signal(2) is really not portable. So use sigaction. // http://man7.org/linux/man-pages/man2/signal.2.html static inline void setSignalHandler( int signal, void (*handler)(int, siginfo_t*, void*), struct sigaction* old_sa_ptr) { struct sigaction sa {}; sa.sa_sigaction = handler; sa.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER; if (sigemptyset(&sa.sa_mask) != 0 || sigaction(signal, &sa, old_sa_ptr) != 0) { std::ostringstream oss; oss << "An error occurred while setting handler for " << strsignal(signal) << "."; throw std::runtime_error(oss.str()); } } SIGNAL_HANDLER( SIGBUS, handler_SIGBUS, "ERROR: Unexpected bus error encountered in worker. " "This might be caused by insufficient shared memory (shm).\n"); SIGNAL_HANDLER( SIGSEGV, handler_SIGSEGV, "ERROR: Unexpected segmentation fault encountered in worker.\n"); SIGNAL_HANDLER( SIGFPE, handler_SIGFPE, "ERROR: Unexpected floating-point exception encountered in worker.\n"); // When an error happened in DataLoader methods and Python starts to exit, the // error trace will keep the loader alive, and Python may kill the children // processes first before deleting the loader object. Then the cleaning up // methods in DataLoader.__del__ are not yet called, and SIGCHILD will print an // error saying a worker is killed by SIGTERM. So we suppress SIGTERM from main // loader process here to avoid this by _exit(EXIT_SUCCESS). Note that if we // exit with nonzero code, the loader SIGCHLD handler may report RuntimeError // again, and then it defeats the whole purpose. static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) { if (info->si_pid == getppid()) { _exit(EXIT_SUCCESS); } struct sigaction sa {}; sa.sa_handler = SIG_DFL; sa.sa_flags = 0; if (sigemptyset(&sa.sa_mask) != 0 || sigaction(SIGTERM, &sa, nullptr) != 0) { _exit(EXIT_FAILURE); } else { raise(SIGTERM); } } __attribute__((weak)) void setDataLoaderSignalHandlers() {} static PyObject* THPModule_setWorkerSignalHandlers( PyObject* module, PyObject* arg) { HANDLE_TH_ERRORS setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr); setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr); setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr); setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr); setDataLoaderSignalHandlers(); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } static std::map> worker_pids = {}; static PyObject* THPModule_errorIfAnyWorkerFails( PyObject* module, PyObject* noargs) { HANDLE_TH_ERRORS // Only check the pids we care about for (auto& w : worker_pids) { auto& pid_set = w.second; for (auto worker_pid : pid_set) { // Use waitid rather than waitpid so that we can set NOWAIT, and that // Python and other handlers can get whatever info they want about the // child. siginfo_t infop{}; infop.si_pid = 0; auto error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT); // ignore errors and case with no waitable child if (error < 0 || infop.si_pid == 0) continue; if (infop.si_code == CLD_EXITED && infop.si_status != EXIT_SUCCESS) { // exit with error std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") exited " << "unexpectedly with exit code " << infop.si_status << ". " << "Details are lost due to multiprocessing. Rerunning with " << "num_workers=0 may give better error trace."; // This is necessary. Otherwise, the runtime error will kill the other // workers, and trigger this again. pid_set.clear(); throw std::runtime_error(oss.str()); } else if ( infop.si_code == CLD_KILLED || infop.si_code == CLD_DUMPED) { // killed by signal std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") is killed " << "by signal: " << strsignal(infop.si_status) << ". "; if (infop.si_status == SIGBUS) { oss << "It is possible that dataloader's workers are out of shared memory. " << "Please try to raise your shared memory limit."; } // This is necessary. Otherwise, the runtime error will kill the other // workers, and trigger this again. pid_set.clear(); throw std::runtime_error(oss.str()); } } } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple // of pids we are interested in. static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) { HANDLE_TH_ERRORS TORCH_CHECK_TYPE( PyTuple_GET_SIZE(args) == 2, "_set_worker_pids expects exactly 2 arguments."); int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0)); TORCH_CHECK_VALUE( worker_pids.find(key) == worker_pids.end(), "_set_worker_pids should be called only once for each _BaseDataLoaderIter."); PyObject* child_pids = PyTuple_GET_ITEM(args, 1); TORCH_CHECK_TYPE( PyTuple_Check(child_pids), "_set_worker_pids expects a tuple for child_pids, but got ", Py_TYPE(child_pids)->tp_name, "."); std::set pids_set = {}; auto size = PyTuple_GET_SIZE(child_pids); for (const auto idx : c10::irange(size)) { PyObject* obj = PyTuple_GET_ITEM(child_pids, idx); pids_set.insert(static_cast(THPUtils_unpackLong(obj))); } worker_pids[key] = pids_set; Py_RETURN_NONE; END_HANDLE_TH_ERRORS } static PyObject* THPModule_removeWorkerPIDs( PyObject* module, PyObject* loader_id) { HANDLE_TH_ERRORS int64_t key = THPUtils_unpackLong(loader_id); auto it = worker_pids.find(key); TORCH_CHECK_VALUE( it != worker_pids.end(), "Cannot find worker information for _BaseDataLoaderIter with id ", key); worker_pids.erase(it); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } #undef SIGNAL_HANDLER #else // dummy implementations for windows static PyObject* THPModule_setWorkerSignalHandlers( PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } static PyObject* THPModule_removeWorkerPIDs( PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } static PyObject* THPModule_errorIfAnyWorkerFails( PyObject* module, PyObject* _ignored) { Py_RETURN_NONE; } #endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) PyMethodDef DataLoaderMethods[] = { {"_set_worker_signal_handlers", THPModule_setWorkerSignalHandlers, METH_NOARGS, nullptr}, {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr}, {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr}, {"_error_if_any_worker_fails", THPModule_errorIfAnyWorkerFails, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr}};