#include <torch/csrc/autograd/python_engine.h>

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/LegacyVmapMode.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>

#ifndef _WIN32
#include <pthread.h>
#endif

#include <memory> // for unique_ptr
#include <utility>

using namespace torch::autograd;

struct THPEngine {
  PyObject_HEAD
};

static bool _reinitialize_engine = false;

namespace torch::autograd::python {

PythonEngine::PythonEngine() = default;

Engine& PythonEngine::get_python_engine() {
  static PythonEngine engine;
  // This is "probably" thread-safe because the flag is set in a fork handler
  // before any threads are created, and this function is only called with the
  // GIL held. However, using fork + threads is playing with fire so this is
  // more of a "best effort" thing. For example, if the fork occurs while the
  // backwards threads hold a lock, we'll probably deadlock in the engine
  // destructor.
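  //
  // Rough illustration of the Python-level behavior (stated as an assumption,
  // not a guarantee): after an os.fork(), the atfork child handler below
  // (child_atfork) sets _reinitialize_engine, so the child's first backward()
  // call rebuilds the engine here instead of reusing worker-thread state
  // inherited from the parent.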
  if (_reinitialize_engine) {
    engine.release_workers();
    engine.~PythonEngine();
    new (&engine) torch::autograd::python::PythonEngine();
    _reinitialize_engine = false;
  }
  return engine;
}

PythonEngine::~PythonEngine() {
  Engine::stop();
}

#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9
#define IS_PYTHON_3_9_PLUS
#endif
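// On 3.9+, the GIL guard in thread_init below is heap-allocated so that its
// destructor can be skipped when the interpreter is finalizing (see the
// Py_IsInitialized() check in thread_init).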

void PythonEngine::thread_init(
    int device,
    const std::shared_ptr<ReadyQueue>& ready_queue,
    bool should_increment) {
  // Increment thread usage count before acquiring the GIL
  if (should_increment) {
    increment_non_reentrant_thread_count();
  }
  // Create a PyThreadState, but release the GIL. This lets
  // pybind11::gil_scoped_acquire calls inside thread_main acquire the GIL
  // without having to create a new PyThreadState each time.
#if defined(IS_PYTHON_3_9_PLUS)
  auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
#else
  pybind11::gil_scoped_acquire gil;
#endif
  pybind11::gil_scoped_release no_gil;
  Engine::thread_init(device, ready_queue, false);

  if (should_increment) {
    // Decrement the count during shutdown if we incremented earlier.
    decrement_non_reentrant_thread_count();
  }

#if defined(IS_PYTHON_3_9_PLUS)
  // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if
  // runtime is finalizing
  if (!Py_IsInitialized()) {
    no_gil.disarm();
    // TODO: call disarm once PyThreadState_Clear can safely be called from
    // finalize.
    // NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct
    // PyThreadState, so avoid use-after-free here.
    auto ptr = gil.release();
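    // Free the gil_scoped_acquire's storage without running its destructor,
    // which would call back into the finalizing interpreter
    // (PyEval_RestoreThread, PyThreadState_Clear / PyThreadState_DeleteCurrent).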
    operator delete(ptr);
  }
#endif
}

void PythonEngine::thread_on_exception(
    const std::shared_ptr<GraphTask>& graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  // See Note [ Persisting PyErr state across autograd engine threads ]
  auto python_err = dynamic_cast<python_error*>(&e);
  if (python_err) {
    python_err->persist();
  }
  Engine::thread_on_exception(graph_task, fn, e);
}

std::unique_ptr<AnomalyMetadata> PythonEngine::make_anomaly_metadata() {
  return std::make_unique<PyAnomalyMetadata>();
}

std::unique_ptr<SavedVariableHooks> PythonEngine::
    get_default_saved_variable_hooks() {
  return PyDefaultSavedVariableHooks::get_hooks();
}

variable_list PythonEngine::execute(
    const edge_list& roots,
    const variable_list& inputs,
    bool keep_graph,
    bool create_graph,
    bool accumulate_grad,
    const edge_list& outputs) {
  TORCH_CHECK(
      !PyGILState_Check(),
      "The autograd engine was called while holding the GIL. If you are using the C++ "
      "API, the autograd engine is an expensive operation that does not require the "
      "GIL to be held so you should release it with 'pybind11::gil_scoped_release no_gil;'"
      ". If you are not using the C++ API, please report a bug to the pytorch team.")
  try {
    return Engine::execute(
        roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
  } catch (python_error& e) {
    e.restore();
    throw;
  }
}

c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {
  try {
    return Engine::execute_with_graph_task(
        graph_task, std::move(graph_root), std::move(input_buffer));
  } catch (python_error& e) {
    pybind11::gil_scoped_acquire gil;
    if (!PyErr_Occurred()) {
      // Set the error indicator only if it is not set already.
      e.restore();
    }
    throw;
  }
}
} // namespace torch::autograd::python

PyObject* THPEngineClass = nullptr;

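// `obj` is expected to be a Python GradientEdge, i.e. a (node, output_nr)
// pair (torch.autograd.graph.GradientEdge on the Python side, as of this
// revision); callers check isinstance against THPGradientEdgeClass first.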
inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
  PyObject* grad_fn = PyTuple_GetItem(obj, 0);
  auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
  std::shared_ptr<torch::autograd::Node> grad_fn_sp;
  if (THPFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
  } else if (THPCppFunction_Check(grad_fn)) {
    grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
  } else {
    TORCH_CHECK(
        false,
        "GradientEdge's first object must be an autograd.graph.Node "
        "but got ",
        THPUtils_typename(grad_fn));
  }
  return Edge(grad_fn_sp, output_nr);
}

// Implementation of torch._C._EngineBase.run_backward
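//
// Rough sketch of the Python-side callers (simplified from
// torch/autograd/__init__.py; exact argument lists may differ by version):
//
//   # torch.autograd.backward(...): accumulate into leaf .grad attributes
//   Variable._execution_engine.run_backward(
//       tensors, grad_tensors, retain_graph, create_graph, inputs,
//       allow_unreachable=True, accumulate_grad=True)
//
//   # torch.autograd.grad(outputs, inputs, ...): capture and return gradients
//   Variable._execution_engine.run_backward(
//       outputs, grad_outputs, retain_graph, create_graph, inputs,
//       allow_unused, accumulate_grad=False)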
PyObject* THPEngine_run_backward(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  PyObject* tensors = nullptr;
  PyObject* grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject* inputs = nullptr;
  unsigned char allow_unreachable = 0;
  unsigned char accumulate_grad =
      0; // Indicate whether to accumulate grad into leaf Tensors (backward())
         // or capture and return them for the given inputs (grad())
  constexpr const char* accepted_kwargs[] = {// NOLINT
                                             "tensors",
                                             "grad_tensors",
                                             "keep_graph",
                                             "create_graph",
                                             "inputs",
                                             "allow_unreachable",
                                             "accumulate_grad",
                                             nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "OObb|Obb",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,-warnings-as-errors)
          const_cast<char**>(accepted_kwargs),
          &tensors,
          &grad_tensors,
          &keep_graph,
          &create_graph,
          &inputs,
          &allow_unreachable,
          &accumulate_grad))
    return nullptr;
  TORCH_CHECK(
      PyTuple_Check(tensors),
      "tensors argument is expected to "
      "be a tuple, but got ",
      THPUtils_typename(tensors));
  TORCH_CHECK(
      PyTuple_Check(grad_tensors),
      "grad_tensors argument is "
      "expected to be a tuple, but got ",
      THPUtils_typename(grad_tensors));

  Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors);
  Py_ssize_t num_gradients = PyTuple_GET_SIZE(grad_tensors);
  TORCH_CHECK(
      num_tensors == num_gradients,
      "got ",
      num_tensors,
      " tensors and ",
      num_gradients,
      " gradients");

  // The user either called autograd.backward(...) or autograd.grad(...) to get
  // here
  bool backward_api_called = accumulate_grad;
  TORCH_CHECK(
      !backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
      "backward() called inside torch.vmap. This is not supported, "
      "please call backward() outside torch.vmap or instead use "
      "torch.autograd.grad inside torch.vmap");

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
    Edge gradient_edge; // Temporary variable to hold the gradient edge
    std::optional<at::Tensor> mb_output;
    if (THPVariable_Check(_tensor)) {
      mb_output = THPVariable_Unpack(_tensor);
      TORCH_CHECK(
          !isBatchedTensor(mb_output.value()),
          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
          "torch.vmap. We do not support the case where any outputs are ",
          "vmapped tensors (output ",
          i,
          " is being vmapped over). Please "
          "call autograd.grad() outside torch.vmap or file a bug report "
          "with your use case.");
      gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
    } else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
      gradient_edge = parseGradientEdge(_tensor, i);
    } else {
      TORCH_CHECK(
          false,
          "element ",
          i,
          " of tensors tuple is neither a Tensor nor a GradientEdge");
    }
    TORCH_CHECK(
        gradient_edge.function,
        "element ",
        i,
        " of tensors does not require grad and does not have a grad_fn");
    roots.push_back(std::move(gradient_edge));

    PyObject* grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      if (grad_var.has_names()) {
        TORCH_WARN(
            "Autograd was passed a named grad tensor with dims ",
            grad_var.names(),
            ". Autograd does not yet support named tensor semantics, so all names ",
            "will be ignored. In practice all computed gradients will still be correct "
            "according to regular tensor semantics.");
      }
      grads.push_back(grad_var);
    } else {
      TORCH_CHECK(
          grad == Py_None,
          "element ",
          i,
          " of gradients tuple is not a Tensor or None");
      TORCH_CHECK(
          mb_output.has_value(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding output is a GradientEdge. "
          "This is not supported.");
      TORCH_CHECK(
          !mb_output.value().requires_grad(),
          "element ",
          i,
          " of gradients tuple is None, but the corresponding Tensor requires grad");
    }
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    TORCH_CHECK(
        PyTuple_CheckExact(inputs), "inputs to run_backward must be a tuple");
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = PyTuple_GET_ITEM(inputs, i);
      if (THPVariable_Check(input)) {
        const auto& tensor = THPVariable_Unpack(input);
        TORCH_CHECK(
            !isBatchedTensor(tensor),
            "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
            "torch.vmap. We do not support the case where any inputs are ",
            "vmapped tensors (input ",
            i,
            " is being vmapped over). Please "
            "call autograd.grad() outside torch.vmap or file a bug report "
            "with your use case.")
        const auto output_nr = tensor.output_nr();
        auto grad_fn = tensor.grad_fn();
        if (!grad_fn) {
          grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
        }
        if (accumulate_grad) {
          tensor.retain_grad();
        }
        TORCH_CHECK(
            tensor.requires_grad(),
            "One of the differentiated Tensors does not require grad");
        if (!grad_fn) {
          // NOTE [ Autograd Unreachable Input ]
          // Since input has no grad_accumulator, it's guaranteed to be
          // unreachable. We initialize an edge pointing to a non-nullptr Node
          // so nodes in the graph (e.g., mul when an operand is scalar) that
          // have edges pointing to nullptr don't get erroneously assigned
          // `needed = True` in exec_info.
          output_edges.emplace_back(std::make_shared<Identity>(), 0);
        } else {
          output_edges.emplace_back(grad_fn, output_nr);
        }
      } else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
        output_edges.emplace_back(parseGradientEdge(input, i));
      } else {
        TORCH_CHECK(
            false,
            "all inputs have to be Tensors or GradientEdges, but got ",
            THPUtils_typename(input));
      }
    }
  }


  variable_list outputs;
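  // Run the backward pass with the GIL released; PythonEngine::execute above
  // checks that the GIL is not held while the engine runs.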
  {
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(
        roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }

  if (!backward_api_called && inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    THPObjectPtr py_outputs{PyTuple_New(num_inputs)};
    if (!py_outputs)
      return nullptr;
    for (const auto i : c10::irange(num_inputs)) {
      TORCH_CHECK(
          allow_unreachable || outputs[i].defined(),
          "One of the "
          "differentiated Tensors appears to not have been used "
          "in the graph. Set allow_unused=True if this is the "
          "desired behavior.");
      PyTuple_SET_ITEM(py_outputs.get(), i, THPVariable_Wrap(outputs[i]));
    }
    return py_outputs.release();
  } else {
    Py_RETURN_NONE;
  }
  END_HANDLE_TH_ERRORS
}

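// Implementation of torch._C._EngineBase.queue_callback. Callbacks queued
// while a backward pass is running are executed once the current GraphTask
// finishes. A rough Python-side sketch (Variable._execution_engine is the
// _ImperativeEngine instance set up in torch.autograd; treat the exact entry
// point as an assumption):
//
//   torch.autograd.Variable._execution_engine.queue_callback(
//       lambda: print("backward finished"))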
PyObject* THPEngine_queue_callback(PyObject* self, PyObject* _callback) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  std::shared_ptr<PyObject> callback(_callback, [](PyObject* obj) {
    pybind11::gil_scoped_acquire gil;
    Py_DECREF(obj);
  });
  Py_INCREF(_callback);
  engine.queue_callback([callback]() {
    pybind11::gil_scoped_acquire gil;
    THPObjectPtr result{PyObject_CallFunctionObjArgs(callback.get(), nullptr)};
    if (!result) {
      // Note [ Persisting PyErr state across autograd engine threads ]
      //
      // Since the autograd engine is multi-threaded, and Python error state is
      // local to each thread, we must preserve the Python error from the worker
      // thread and rethrow it as-is in the calling thread. This is done via
      // persisting the error in the two places that can encounter Python
      // errors: (1) evaluate function and (2) queued callbacks.
      //
      // TODO: the engine is not actually responsible for persisting the error
      // in the custom autograd Function case today! See the note above
      // `raise_python_error()` function in python_function.cpp and
      // python_hooks.cpp for more details. Persisting an extra time in the
      // engine is fine because doing so is a no-op when the python_error has
      // already been persisted.
      python_error err;
      err.persist();
      throw std::move(err);
    }
  });
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

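// Implementation of torch._C._EngineBase.is_checkpoint_valid. Presumably
// consumed on the Python side by torch.utils.checkpoint, which refuses to run
// its recomputation when the engine reports that checkpointing is not valid
// for the current backward call (e.g. a partial backward started via
// autograd.grad); the exact Python-side caller is an assumption here.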
PyObject* THPEngine_is_checkpoint_valid(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto& engine = python::PythonEngine::get_python_engine();
  if (engine.is_checkpoint_valid()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPEngine_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  return type->tp_alloc(type, 0);
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
    {(char*)"run_backward",
     castPyCFunctionWithKeywords(THPEngine_run_backward),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
    {(char*)"is_checkpoint_valid",
     THPEngine_is_checkpoint_valid,
     METH_NOARGS,
     nullptr},
    {nullptr}};

PyTypeObject THPEngineType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._EngineBase", /* tp_name */
    sizeof(THPEngine), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPEngine_methods, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPEngine_new /* tp_new */
};

static void child_atfork() {
  _reinitialize_engine = true;
}

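// Registers the engine type with the given module as _ImperativeEngine and
// installs the fork handler that requests engine reinitialization in the
// child process (see child_atfork / get_python_engine above).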
bool THPEngine_initModule(PyObject* module) {
#ifndef _WIN32
  if (pthread_atfork(nullptr, nullptr, child_atfork) != 0) {
    throw std::runtime_error("unable to set pthread_atfork handler");
  }
#endif
  if (PyType_Ready(&THPEngineType) < 0)
    return false;
  Py_INCREF(&THPEngineType);
  PyModule_AddObject(module, "_ImperativeEngine", (PyObject*)&THPEngineType);
  set_default_engine_stub(python::PythonEngine::get_python_engine);
  return true;
}