1 #pragma once 2 3 #include <torch/csrc/python_headers.h> 4 5 #include <torch/csrc/autograd/engine.h> 6 #include <torch/csrc/autograd/function.h> 7 8 bool THPEngine_initModule(PyObject* module); 9 10 namespace torch::autograd::python { 11 12 struct PythonEngine : public Engine { 13 static Engine& get_python_engine(); 14 ~PythonEngine() override; 15 void thread_init( 16 int device, 17 const std::shared_ptr<ReadyQueue>& ready_queue, 18 bool should_increment) override; 19 void thread_on_exception( 20 const std::shared_ptr<GraphTask>& graph_task, 21 const std::shared_ptr<Node>& fn, 22 std::exception& e) override; 23 variable_list execute( 24 const edge_list& roots, 25 const variable_list& inputs, 26 bool keep_graph, 27 bool create_graph, 28 bool accumulate_grad, 29 const edge_list& outputs = {}) override; 30 31 c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task( 32 const std::shared_ptr<GraphTask>& graph_task, 33 std::shared_ptr<Node> graph_root, 34 InputBuffer&& input_buffer) override; 35 36 std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() override; 37 std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() 38 override; 39 40 private: 41 PythonEngine(); 42 }; 43 44 } // namespace torch::autograd::python 45