#include <torch/csrc/autograd/python_engine.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapMode.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>

#ifndef _WIN32
#include <pthread.h>
#endif

#include <memory> // for unique_ptr
#include <utility>

using namespace torch::autograd;

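// Python object backing torch._C._EngineBase (exposed from this module as
// _ImperativeEngine). It carries no per-instance state; its methods simply
// forward to the process-wide PythonEngine singleton below.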
struct THPEngine {
  PyObject_HEAD
};

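// Set by the pthread_atfork() child handler registered in
// THPEngine_initModule(); get_python_engine() consumes it to rebuild the
// engine lazily in a forked child process.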
static bool _reinitialize_engine = false;

namespace torch::autograd::python {

PythonEngine::PythonEngine() = default;

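// Returns the process-wide autograd engine used from Python. The engine is a
// function-local static; after a fork() it is torn down and re-created in
// place (see the comment below and child_atfork()).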
Engine& PythonEngine::get_python_engine() {
  static PythonEngine engine;
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
  if (_reinitialize_engine) {
    engine.release_workers();
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
  return engine;
}

PythonEngine::~PythonEngine() {
  Engine::stop();
}

#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9
#define IS_PYTHON_3_9_PLUS
#endif

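// Per-worker-thread entry point. `should_increment` is true for the
// non-reentrant worker threads; the counter it maintains is incremented before
// the worker loop starts and decremented again when the loop returns during
// shutdown.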
void PythonEngine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
  // Increment thread usage count before acquiring the GIL
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }
  // Create a PyThreadState, but release the GIL. This lets
  // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL
  // without having to create a new PyThreadState each time.
#if defined(IS_PYTHON_3_9_PLUS)
  auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
#else
  pybind11::gil_scoped_acquire gil;
#endif
  pybind11::gil_scoped_release no_gil;
  Engine::thread_init(device, ready_queue, false);

  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }

#if defined(IS_PYTHON_3_9_PLUS)
  // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if
  // runtime is finalizing
  if (!Py_IsInitialized()) {
    no_gil.disarm();
    // TODO: call disarm once PyThreadState_Clear can safely be called from
    // finalize
    // NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct
    // PyThreadState, so avoid use-after-free here.
    auto ptr = gil.release();
    operator delete(ptr);
  }
#endif
}

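// Invoked on a worker thread when evaluating a node throws. If the exception
// carries Python error state (python_error), persist it so it survives the
// hop back to the thread that called backward; see
// Note [ Persisting PyErr state across autograd engine threads ].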
void PythonEngine::thread_on_exception(
    const std::shared_ptr<GraphTask>& graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  // See Note [ Persisting PyErr state across autograd engine threads ]
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(graph_task, fn, e);
}

std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  return std::make_unique<PyAnomalyMetadata>();
}

std::unique_ptr<SavedVariableHooks> PythonEngine::
    get_default_saved_variable_hooks() {
  return PyDefaultSavedVariableHooks::get_hooks();
}

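// Callers that hold the GIL must release it before entering the engine, e.g.
// (sketch, mirroring what THPEngine_run_backward below does):
//
//   pybind11::gil_scoped_release no_gil;
//   auto& engine = python::PythonEngine::get_python_engine();
//   auto outputs = engine.execute(
//       roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
//
// The TORCH_CHECK at the top of execute() enforces this.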
variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) {
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.");
  try {
    return Engine::execute(
        roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}

c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  try {
    return Engine::execute_with_graph_task(
        graph_task, std::move(graph_root), std::move(input_buffer));
  } catch (python_error& e) {
    pybind11::gil_scoped_acquire gil;
    if (!PyErr_Occurred()) {
      // Set the error indicator only if it is not set already.
      e.restore();
    }
    throw;
  }
}

} // namespace torch::autograd::python

PyObject* THPEngineClass = nullptr;

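// Converts a Python GradientEdge-like object, i.e. a (node, output_nr) pair,
// into a C++ Edge. The node may be either a Python-defined autograd function
// (THPFunction) or a wrapper around a C++ Node (THPCppFunction).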
inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
  PyObject* grad_fn = PyTuple_GetItem(obj, 0);
  auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
  std::shared_ptr<torch::autograd::Node> grad_fn_sp;
  if (THPFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
  } else if (THPCppFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
  } else {
    TORCH_CHECK(
        false,
        "GradientEdge's first object must be an autograd.graph.Node "
        "but got ",
        THPUtils_typename(grad_fn));
  }
  return Edge(grad_fn_sp, output_nr);
}

// Implementation of torch._C._EngineBase.run_backward
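// Roughly, the Python side invokes this as (sketch; the exact keyword names
// are given by accepted_kwargs below):
//
//   Variable._execution_engine.run_backward(
//       tensors, grad_tensors, keep_graph, create_graph,
//       inputs, allow_unreachable, accumulate_grad)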
PyObject* THPEngine_run_backward(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  PyObject* tensors = nullptr;
  PyObject* grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject* inputs = nullptr;
  unsigned char allow_unreachable = 0;
  // Indicates whether to accumulate grads into the leaf Tensors' .grad fields
  // (autograd.backward) or to capture and return them (autograd.grad).
  unsigned char accumulate_grad = 0;
  constexpr const char* accepted_kwargs[] = {// NOLINT
      "tensors",
      "grad_tensors",
      "keep_graph",
      "create_graph",
      "inputs",
      "allow_unreachable",
      "accumulate_grad",
      nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,-warnings-as-errors)
          const_cast<char**>(accepted_kwargs),
          &tensors,
          &grad_tensors,
          &keep_graph,
          &create_graph,
          &inputs,
          &allow_unreachable,
          &accumulate_grad))
    return nullptr;
  TORCH_CHECK(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got ",
      THPUtils_typename(tensors));
  TORCH_CHECK(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got ",
      THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  TORCH_CHECK(
      num_tensors == num_gradients,
      "got ",
      num_tensors,
      " tensors and ",
      num_gradients,
      " gradients");

  // The user either called autograd.backward(...) or autograd.grad(...) to
  // get here
  bool backward_api_called = accumulate_grad;
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

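  // Build the graph roots (one edge per element of `tensors`, which may be
  // Tensors or GradientEdges) together with the matching incoming gradients.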
  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
    Edge gradient_edge; // Temporary variable to hold the gradient edge
    std::optional<at::Tensor> mb_output;
    if (THPVariable_Check(_tensor)) {
      mb_output = THPVariable_Unpack(_tensor);
      TORCH_CHECK(
          !isBatchedTensor(mb_output.value()),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any outputs are ",
          "vmapped tensors (output ",
          i,
          " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
      gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
    } else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
      gradient_edge = parseGradientEdge(_tensor, i);
    } else {
      TORCH_CHECK(
          false,
          "element ",
          i,
          " of tensors tuple is neither a Tensor nor a GradientEdge");
    }
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));

    PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      TORCH_CHECK(
          grad == Py_None,
          "element ",
          i,
          " of gradients tuple is not a Tensor or None");
      TORCH_CHECK(
          mb_output.has_value(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding output is a "
          "GradientEdge. This is not supported.");
      TORCH_CHECK(
          !mb_output.value().requires_grad(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }

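  // When explicit `inputs` were passed (autograd.grad, or backward(inputs=...)),
  // translate them into the set of output edges whose gradients the engine
  // should compute.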
  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    TORCH_CHECK(
        PyTuple_CheckExact(inputs), "inputs to run_backward must be a tuple");
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = PyTuple_GET_ITEM(inputs, i);
      if (THPVariable_Check(input)) {
        const auto& tensor = THPVariable_Unpack(input);
        TORCH_CHECK(
            !isBatchedTensor(tensor),
            "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
            "torch.vmap. We do not support the case where any inputs are ",
            "vmapped tensors (input ",
            i,
            " is being vmapped over). Please "
            "call autograd.grad() outside torch.vmap or file a bug report "
            "with your use case.");
        const auto output_nr = tensor.output_nr();
        auto grad_fn = tensor.grad_fn();
        if (!grad_fn) {
          grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
        }
        if (accumulate_grad) {
          tensor.retain_grad();
        }
        TORCH_CHECK(
            tensor.requires_grad(),
            "One of the differentiated Tensors does not require grad");
        if (!grad_fn) {
          // NOTE [ Autograd Unreachable Input ]
          // Since the input has no grad_accumulator, it's guaranteed to be
          // unreachable. We initialize an edge pointing to a non-nullptr Node
          // so nodes in the graph (e.g., mul when an operand is scalar) that
          // have edges pointing to nullptr don't get erroneously assigned
          // `needed = True` in exec_info.
          output_edges.emplace_back(std::make_shared<Identity>(), 0);
        } else {
          output_edges.emplace_back(grad_fn, output_nr);
        }
      } else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
        output_edges.emplace_back(parseGradientEdge(input, i));
      } else {
        TORCH_CHECK(
            false,
            "all inputs have to be Tensors or GradientEdges, but got ",
            THPUtils_typename(input));
      }
    }
  }

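  // Run backward with the GIL released; the engine's worker threads
  // re-acquire the GIL only when they need to touch Python objects.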
  variable_list outputs;
  {
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(
        roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }

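  // autograd.grad-style calls (accumulate_grad == false) return the computed
  // gradients as a tuple with one entry per requested input; backward-style
  // calls have already accumulated into .grad, so they just return None.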
  if (!backward_api_called && inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs{PyTuple_New(num_inputs)};
    if (!py_outputs)
      return nullptr;
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

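// Implementation of torch._C._EngineBase.queue_callback: schedules a Python
// callable to run once the current backward pass finishes. The callback is
// held via a shared_ptr whose deleter re-acquires the GIL before dropping the
// reference, since the last owner may be an engine worker thread.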
PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) {
      // Note [ Persisting PyErr state across autograd engine threads ]
      //
      // Since the autograd engine is multi-threaded, and Python error state is
      // local to each thread, the engine must preserve the Python error from
      // the worker thread and rethrow it as-is in the calling thread. This is
      // done by persisting the error in the two places that can encounter
      // Python errors: (1) evaluate function and (2) queued callbacks.
      //
      // TODO: the engine is not actually responsible for persisting the error
      // in the custom autograd Function case today! See the note above
      // `raise_python_error()` function in python_function.cpp and
      // python_hooks.cpp for more details. Persisting an extra time in the
      // engine is fine because doing so is a no-op when the python_error has
      // already been persisted.
      python_error err;
      err.persist();
      throw std::move(err);
    }
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

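// Implementation of torch._C._EngineBase.is_checkpoint_valid: a thin wrapper
// around Engine::is_checkpoint_valid() that returns a Python bool.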
PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  return type->tp_alloc(type, 0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
    {(char*)"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {(char*)"is_checkpoint_valid",
     THPEngine_is_checkpoint_valid,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyTypeObject THPEngineType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
    sizeof(THPEngine), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPEngine_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPEngine_new /* tp_new */
};

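// Child-side pthread_atfork handler (registered in THPEngine_initModule):
// mark the engine for lazy re-initialization in the forked child, since the
// parent's worker threads do not survive fork().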
static void child_atfork() {
  _reinitialize_engine = true;
}

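// Module init: installs the atfork handler above, exposes the engine type to
// Python as torch._C._ImperativeEngine, and registers
// PythonEngine::get_python_engine as the default engine factory via
// set_default_engine_stub().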
bool THPEngine_initModule(PyObject* module) {
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType);
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}