xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_anomaly_mode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <pybind11/pybind11.h>
4 #include <torch/csrc/autograd/anomaly_mode.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/pybind.h>
7 
8 namespace torch::autograd {
9 
10 struct PyAnomalyMetadata : public AnomalyMetadata {
11   static constexpr const char* ANOMALY_TRACE_KEY = "traceback_";
12   static constexpr const char* ANOMALY_PARENT_KEY = "parent_";
13 
PyAnomalyMetadataPyAnomalyMetadata14   PyAnomalyMetadata() {
15     pybind11::gil_scoped_acquire gil;
16     // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
17     dict_ = PyDict_New();
18   }
~PyAnomalyMetadataPyAnomalyMetadata19   ~PyAnomalyMetadata() override {
20     // If python is already dead, leak the wrapped python objects
21     if (Py_IsInitialized()) {
22       pybind11::gil_scoped_acquire gil;
23       Py_DECREF(dict_);
24     }
25   }
26   void store_stack() override;
27   void print_stack(const std::string& current_node_name) override;
28   void assign_parent(const std::shared_ptr<Node>& parent_node) override;
29 
dictPyAnomalyMetadata30   PyObject* dict() {
31     return dict_;
32   }
33 
34  private:
35   PyObject* dict_{nullptr};
36 };
37 void _print_stack(
38     PyObject* trace_stack,
39     const std::string& current_node_name,
40     bool is_parent);
41 
42 } // namespace torch::autograd
43