#pragma once

#include <ATen/core/Tensor.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pythoncapi_compat.h>

#include <ATen/core/function_schema.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/pybind.h>

namespace py = pybind11;

// Python object that backs torch.autograd.Variable
struct THPVariable {
  PyObject_HEAD;
  // Payload
  c10::MaybeOwned<at::Tensor> cdata;
  // Hooks to be run on the backward pass (corresponds to Python attr
  // '_backwards_hooks', set by 'register_hook')
  PyObject* backward_hooks = nullptr;
  // Hooks to be run in the backward pass after accumulate grad,
  // i.e., after the .grad has been set (corresponds to Python attr
  // '_post_accumulate_grad_hooks', set by 'register_post_accumulate_grad_hook')
  PyObject* post_accumulate_grad_hooks = nullptr;
};
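// Illustrative sketch (not declared in this header): given a PyObject* known
// to point at a THPVariable, the tensor payload is reached by casting to the
// struct and dereferencing `cdata`; this is exactly what THPVariable_Unpack
// below does.
//
//   auto* self = reinterpret_cast<THPVariable*>(obj);
//   const at::Tensor& t = *self->cdata;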

TORCH_PYTHON_API void registerPythonTensorClass(
    const std::string& device,
    PyObject* python_tensor_class);

TORCH_PYTHON_API void activateGPUTrace();

TORCH_PYTHON_API extern PyObject* THPVariableClass;
TORCH_PYTHON_API extern PyObject* ParameterClass;

bool THPVariable_initModule(PyObject* module);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var);
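// Illustrative caller-side sketch (not part of this header): wrapping an
// at::Tensor into a Python Tensor object. The returned PyObject* is expected
// to be a new reference that the caller releases (it may be nullptr on error).
//
//   at::Tensor t = at::zeros({2, 3});
//   PyObject* py_t = THPVariable_Wrap(t);
//   // ... hand py_t to Python, or:
//   Py_XDECREF(py_t);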

inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
  // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
  // (A subclass could have different semantics.) The one exception is
  // Parameter, which is used for Python bookkeeping but is equivalent to
  // Tensor as far as C++ is concerned.
  return (
      tp == (PyTypeObject*)THPVariableClass ||
      tp == (PyTypeObject*)ParameterClass);
}
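// Illustrative behaviour (not part of this header), assuming `tensor_obj`
// points at a plain torch.Tensor and `subclass_obj` at an instance of a
// user-defined Tensor subclass:
//
//   THPVariable_CheckTypeExact(Py_TYPE(tensor_obj));    // true
//   THPVariable_CheckTypeExact(Py_TYPE(subclass_obj));  // false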

inline bool THPVariable_CheckExact(PyObject* obj) {
  return THPVariable_CheckTypeExact(Py_TYPE(obj));
}

inline bool THPVariable_Check(PyObject* obj) {
  if (!THPVariableClass)
    return false;

  // Fast path
  if (THPVariable_CheckExact(obj)) {
    return true;
  }

  const auto result = PyObject_IsInstance(obj, THPVariableClass);
  if (result == -1)
    throw python_error();
  return result;
}
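// Illustrative behaviour (not part of this header), with the same
// `tensor_obj` / `subclass_obj` as above: the subclass case falls through the
// exact-type fast path to the PyObject_IsInstance check, which throws
// python_error if the check itself errors out.
//
//   THPVariable_Check(tensor_obj);    // true (fast path)
//   THPVariable_Check(subclass_obj);  // true (isinstance path)
//   THPVariable_Check(Py_None);       // false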

inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
  return *var->cdata;
}

inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
  return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
}
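// Note that the PyObject* overload above does not validate its argument; an
// illustrative safe pattern (not part of this header) is to guard the cast
// with THPVariable_Check first:
//
//   if (THPVariable_Check(obj)) {
//     const at::Tensor& t = THPVariable_Unpack(obj);
//     // ... use t while `obj` stays alive ...
//   }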

std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments);

void pushPyOutToStack(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    py::object out,
    const char* msg);

inline PyObject* THPVariable_WrapList(
    const torch::autograd::variable_list& inputs) {
  PyObject* pyinput = PyList_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {
    PyList_SET_ITEM(pyinput, i, THPVariable_Wrap(inputs[i]));
  }
  return pyinput;
}

inline torch::autograd::variable_list THPVariable_UnpackList(
    PyObject* pyresult) {
  TORCH_CHECK(PyList_CheckExact(pyresult));
  auto result_len = PyList_GET_SIZE(pyresult);
  torch::autograd::variable_list result;
  result.reserve(result_len);
  for (const auto i : c10::irange(result_len)) {
    PyObject* item = PyList_GET_ITEM(pyresult, i);
    if (!Py_IsNone(item)) {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(THPVariable_Check(item));
      result.emplace_back(THPVariable_Unpack(item));
    } else {
      result.emplace_back();
    }
  }
  return result;
}
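// Illustrative round trip (not part of this header): a variable_list can be
// handed to Python as a list of Tensors and recovered again; undefined
// tensors are expected to map to None and back to undefined tensors.
//
//   torch::autograd::variable_list vars = {at::ones({2}), at::Tensor()};
//   PyObject* pylist = THPVariable_WrapList(vars);            // [Tensor, None]
//   torch::autograd::variable_list back = THPVariable_UnpackList(pylist);
//   Py_DECREF(pylist);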