#pragma once

#include <torch/csrc/python_headers.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/object_ptr.h>

#include <c10/core/DeviceGuard.h>
#include <optional>

#include <memory>
#include <vector>

namespace torch::jit {
struct Graph;
}

namespace torch::autograd {

// A Function which is implemented by a Python object (i.e., a THPFunction).
// Calls to 'apply' are forwarded to the Python method implementation.
struct PyNode : public Node {
  PyNode(THPObjectPtr obj) : obj(obj.release()) {}

  PyObject* to_py_args(
      const variable_list& inputs,
      at::OptionalDeviceGuard* device_guard);
  variable_list to_variable_list(
      const PyObject* r,
      const std::vector<bool>& is_variable_input);

  variable_list apply(variable_list&& inputs) override;
  variable_list defer_to_dynamo(
      variable_list&& inputs,
      std::optional<PyObject*> compiler);

  void release_variables() override;
  std::string name() const override;
  bool is_traceable() override;

  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  bool compiled_autograd_should_lift() const;

  // THPFunction this Function is wrapping.  Owning!
  PyObject* obj;

  // The AutogradCompilerCall::hooks idx corresponding to this node's backward
  std::optional<int> _backward_idx;

  // The AutogradCompilerCall::hooks idx corresponding to this node's
  // backward_state
  std::optional<int> _backward_state_idx;

  // NOLINTNEXTLINE(bugprone-exception-escape)
  ~PyNode() override {
    // Can't use THPObjectPtr as a field in this class; its destructor won't
    // take out the GIL!  When I forgot to do this by hand,
    // TestAutograd.test_inplace_view_python called me out about it.
    // If Python is already dead, leak the wrapped Python objects.
    if (Py_IsInitialized()) {
      pybind11::gil_scoped_acquire gil;
      Py_DECREF(obj);
    }
  }
};
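
// The forwarding mentioned above has roughly the following shape (a sketch
// only, with error handling elided; the real implementation lives in
// torch/csrc/autograd/python_function.cpp, and `node` stands for a PyNode&):
//
//   pybind11::gil_scoped_acquire gil;
//   at::OptionalDeviceGuard device_guard;
//   THPObjectPtr py_inputs(node.to_py_args(inputs, &device_guard));
//   THPObjectPtr apply_fn(PyObject_GetAttrString(node.obj, "apply"));
//   THPObjectPtr r(PyObject_CallObject(apply_fn, py_inputs.get()));
//   ensure_tuple(r);
//   return node.to_variable_list(
//       r.get(), ((THPFunction*)node.obj)->is_variable_input);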

/**
 * Wrap an object in a one-element tuple if it is not a tuple already.
 * Returns true if the original object was not a tuple (i.e., it was wrapped).
 */
inline bool ensure_tuple(THPObjectPtr& obj) {
  if (PyTuple_Check(obj.get()))
    return false;

  PyObject* tuple = PyTuple_New(1);
  if (!tuple)
    throw python_error();
  PyTuple_SET_ITEM(tuple, 0, obj.release());
  obj = tuple;
  return true;
}
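
// Example usage (a sketch; `forward_fn` and `args` are hypothetical
// PyObject* values owned by the caller):
//
//   THPObjectPtr outputs(PyObject_CallObject(forward_fn, args));
//   if (!outputs)
//     throw python_error();
//   // Wrap a lone return value so callers can iterate over outputs
//   // uniformly; remember whether it needs to be unwrapped again later.
//   bool unpacked = ensure_tuple(outputs);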

} // namespace torch::autograd

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPFunction {
  PyObject_HEAD

  // Python tuple of booleans, one per forward() input, indicating whether
  // that input requires grad (exposed to Python as ctx.needs_input_grad).
  PyObject* needs_input_grad;

  // Python tuple of tensors whose variables we should save.  Set
  // by Python with 'save_for_backward'.  If nullptr, no tensors were
  // saved.
  PyObject* to_save;
  // Python tuple of tensors which are not differentiable.  Set by
  // Python with 'mark_non_differentiable'.  If nullptr, no tensors were
  // non-differentiable.
  PyObject* non_differentiable;
  // Python tuple of tensors which had inplace updates in the forward()
  // pass.  Set by Python with 'mark_dirty'.  If nullptr, no tensors were
  // modified inplace.
  PyObject* dirty_tensors;

  // boolean indicating whether to materialize undefined output grad tensors
  // into tensors full of zeros. Set by Python with 'set_materialize_grads'.
  // Default is true.
  bool materialize_grads;

  // boolean indicating whether to materialize output grad tensors
  // corresponding to non-differentiable outputs. Normally, someone would
  // already get this behavior by switching off materialize_grads,
  // but there are certain use cases where that is not feasible:
  // https://github.com/pytorch/pytorch/pull/98659#pullrequestreview-1376822560
  bool materialize_non_diff_grads;
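
  // Rough post-forward lifecycle of the bookkeeping fields above (a sketch
  // only; the real logic lives in torch/csrc/autograd/python_function.cpp,
  // and `self` stands for a THPFunction*): after forward() returns,
  //
  //   if (self->to_save) { /* pack the listed tensors into saved_variables */ }
  //   if (self->dirty_tensors) { /* bump version counters of in-place-modified inputs */ }
  //   if (self->non_differentiable) { /* record outputs to exclude from autograd */ }
  //
  // and each of these fields is then reset to nullptr.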

  // This is enabled by compiled autograd as a way to signal to AotAutograd it
  // should call the original FX graph rather than compiling.
  bool compiled_autograd_tracing;
  PyObject* compiled_autograd_backward_state;
  std::vector<c10::SymInt> compiled_autograd_symints;

  std::vector<torch::autograd::VariableInfo> output_info;
  std::vector<torch::autograd::VariableInfo> input_info;
  std::vector<torch::autograd::SavedVariable> saved_variables;
  // For each input, true if the input is a THPVariable
  std::vector<bool> is_variable_input;
  char has_freed_buffers;

  PyObject* saved_for_forward;
  // The actual PyNode (in the autograd graph) that this data was
  // saved for.  This field may be NULL (because a user can construct
  // a THPFunction directly from Python), but when this field is non-NULL,
  // it is guaranteed that cdata.lock()->obj == this
  //
  // In most ordinary use, this field should always be non-NULL; e.g.,
  // when we allocate a THPFunction because we are running Node.apply,
  // after constructing a THPFunction, we immediately allocate a PyNode
  // for it.  We can't enforce this directly in the constructor of
  // THPFunction though, because there's no way to keep it live long enough
  // to save an owning reference to PyNode into the grad_fn of a Variable.
  std::weak_ptr<torch::autograd::PyNode> cdata;
};
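
// Example (a sketch; `fn` is a hypothetical THPFunction* known to have been
// created through Node::apply): recovering the owning PyNode and checking
// the invariant documented above:
//
//   if (auto node = fn->cdata.lock()) {
//     TORCH_INTERNAL_ASSERT(node->obj == (PyObject*)fn);
//   }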

bool THPFunction_initModule(PyObject* module);
extern PyTypeObject THPFunctionType;
extern PyObject* THPFunctionClass;
extern PyObject* THPGradientEdgeClass;

inline bool THPFunction_Check(PyObject* obj) {
  return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
}
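
// Example (a sketch; `candidate` is a hypothetical borrowed PyObject*):
//
//   if (THPFunction_Check(candidate)) {
//     auto* fn = (THPFunction*)candidate;
//     // fields of THPFunction can be accessed safely here
//   }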