xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_engine.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 
5 #include <torch/csrc/autograd/engine.h>
6 #include <torch/csrc/autograd/function.h>
7 
8 bool THPEngine_initModule(PyObject* module);
9 
10 namespace torch::autograd::python {
11 
12 struct PythonEngine : public Engine {
13   static Engine& get_python_engine();
14   ~PythonEngine() override;
15   void thread_init(
16       int device,
17       const std::shared_ptr<ReadyQueue>& ready_queue,
18       bool should_increment) override;
19   void thread_on_exception(
20       const std::shared_ptr<GraphTask>& graph_task,
21       const std::shared_ptr<Node>& fn,
22       std::exception& e) override;
23   variable_list execute(
24       const edge_list& roots,
25       const variable_list& inputs,
26       bool keep_graph,
27       bool create_graph,
28       bool accumulate_grad,
29       const edge_list& outputs = {}) override;
30 
31   c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
32       const std::shared_ptr<GraphTask>& graph_task,
33       std::shared_ptr<Node> graph_root,
34       InputBuffer&& input_buffer) override;
35 
36   std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override;
37   std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks()
38       override;
39 
40  private:
41   PythonEngine();
42 };
43 
44 } // namespace torch::autograd::python
45