1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <torch/csrc/python_headers.h>
5 #include <torch/csrc/utils/pythoncapi_compat.h>
6
7 #include <ATen/core/function_schema.h>
8 #include <pybind11/pybind11.h>
9 #include <torch/csrc/Exceptions.h>
10 #include <torch/csrc/Export.h>
11 #include <torch/csrc/autograd/variable.h>
12 #include <torch/csrc/utils/pybind.h>
13
14 namespace py = pybind11;
15
16 // Python object that backs torch.autograd.Variable
// Python object that backs torch.autograd.Variable.
//
// This is the C layout of the Python-side tensor object: it starts with the
// standard Python object header, and code in this header reinterpret_casts a
// PyObject* straight to THPVariable* (see THPVariable_Unpack). Do not reorder
// or insert fields ahead of PyObject_HEAD.
struct THPVariable {
  PyObject_HEAD;
  // Payload: the wrapped C++ tensor. Held as c10::MaybeOwned, so the object
  // may either own the tensor or borrow it.
  c10::MaybeOwned<at::Tensor> cdata;
  // Hooks to be run on backwards pass (corresponds to Python attr
  // '_backward_hooks', set by 'register_hook'). nullptr until set.
  PyObject* backward_hooks = nullptr;
  // Hooks to be run in the backwards pass after accumulate grad,
  // i.e., after the .grad has been set (corresponds to Python attr
  // '_post_accumulate_grad_hooks', set by 'register_post_accumulate_grad_hook').
  // nullptr until set.
  PyObject* post_accumulate_grad_hooks = nullptr;
};
29
// Defined elsewhere; implementation not visible in this header.
// NOTE(review): from the name and parameters, presumably associates a Python
// tensor class with a device string — confirm against the definition.
TORCH_PYTHON_API void registerPythonTensorClass(
    const std::string& device,
    PyObject* python_tensor_class);

// Defined elsewhere. NOTE(review): name suggests it enables GPU tracing —
// confirm against the definition.
TORCH_PYTHON_API void activateGPUTrace();

// The Python `Tensor` class object. May be null (THPVariable_Check guards
// against that), presumably until module initialization assigns it.
TORCH_PYTHON_API extern PyObject* THPVariableClass;
// The Python `Parameter` class object. The exact-type checks in this header
// treat it as equivalent to Tensor.
TORCH_PYTHON_API extern PyObject* ParameterClass;

// Defined elsewhere. NOTE(review): presumably sets up the Tensor type on
// `module` and returns a success flag — confirm against the definition.
bool THPVariable_initModule(PyObject* module);
// Returns a Python object wrapping `var`. THPVariable_WrapList passes the
// result to PyList_SET_ITEM, which steals a reference, so this is expected to
// return a new reference. NOTE(review): confirm the null-on-failure contract.
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var);
41
THPVariable_CheckTypeExact(PyTypeObject * tp)42 inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
43 // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
44 // (A subclass could have different semantics.) The one exception is
45 // Parameter, which is used for Python bookkeeping but is equivalent to
46 // Tensor as far as C++ is concerned.
47 return (
48 tp == (PyTypeObject*)THPVariableClass ||
49 tp == (PyTypeObject*)ParameterClass);
50 }
51
THPVariable_CheckExact(PyObject * obj)52 inline bool THPVariable_CheckExact(PyObject* obj) {
53 return THPVariable_CheckTypeExact(Py_TYPE(obj));
54 }
55
THPVariable_Check(PyObject * obj)56 inline bool THPVariable_Check(PyObject* obj) {
57 if (!THPVariableClass)
58 return false;
59
60 // Fast path
61 if (THPVariable_CheckExact(obj)) {
62 return true;
63 }
64
65 const auto result = PyObject_IsInstance(obj, THPVariableClass);
66 if (result == -1)
67 throw python_error();
68 return result;
69 }
70
THPVariable_Unpack(THPVariable * var)71 inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
72 return *var->cdata;
73 }
74
THPVariable_Unpack(PyObject * obj)75 inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
76 return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
77 }
78
// Defined elsewhere; implementation not visible in this header.
// NOTE(review): from the signature, presumably converts `arguments` for
// operator `op` into Python positional args (first) and keyword args
// (second) — confirm against the definition.
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments);

// Defined elsewhere; implementation not visible in this header.
// NOTE(review): presumably converts the Python result `out` of `op` and pushes
// it onto `stack`, with `msg` used in error reporting — confirm.
void pushPyOutToStack(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    py::object out,
    const char* msg);
88
THPVariable_WrapList(const torch::autograd::variable_list & inputs)89 inline PyObject* THPVariable_WrapList(
90 const torch::autograd::variable_list& inputs) {
91 PyObject* pyinput = PyList_New(static_cast<Py_ssize_t>(inputs.size()));
92 for (const auto i : c10::irange(inputs.size())) {
93 PyList_SET_ITEM(pyinput, i, THPVariable_Wrap(inputs[i]));
94 }
95 return pyinput;
96 }
97
THPVariable_UnpackList(PyObject * pyresult)98 inline torch::autograd::variable_list THPVariable_UnpackList(
99 PyObject* pyresult) {
100 TORCH_CHECK(PyList_CheckExact(pyresult));
101 auto result_len = PyList_GET_SIZE(pyresult);
102 torch::autograd::variable_list result;
103 result.reserve(result_len);
104 for (const auto i : c10::irange(result_len)) {
105 PyObject* item = PyList_GET_ITEM(pyresult, i);
106 if (!Py_IsNone(item)) {
107 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(THPVariable_Check(item));
108 result.emplace_back(THPVariable_Unpack(item));
109 } else {
110 result.emplace_back();
111 }
112 }
113 return result;
114 }
115