1 #pragma once
2
3 #include <torch/csrc/python_headers.h>
4 #include <memory>
5 #include <typeinfo>
6
7 #include <torch/csrc/Exceptions.h>
8 #include <torch/csrc/autograd/function.h>
9 #include <torch/csrc/utils/object_ptr.h>
10
11 namespace torch::autograd {
12
13 struct THPCppFunction {
14 PyObject_HEAD std::shared_ptr<Node> cdata;
15 };
16
17 template <typename Ctor>
CppFunction_pynew(PyTypeObject * type,PyObject * args,PyObject * kwds)18 PyObject* CppFunction_pynew(
19 PyTypeObject* type,
20 PyObject* args,
21 PyObject* kwds) {
22 THPObjectPtr obj(type->tp_alloc(type, 0));
23 if (!obj)
24 return nullptr;
25 THPCppFunction* f = (THPCppFunction*)obj.get();
26 HANDLE_TH_ERRORS
27 new (&f->cdata) std::shared_ptr<Node>(Ctor()(args));
28 END_HANDLE_TH_ERRORS
29 if (!f->cdata) {
30 return nullptr;
31 }
32 return obj.release();
33 }
34
// PyMethodDef entries (name, C function, calling convention, doc) shared by
// Python wrappers of C++ autograd Nodes. The expansion has no trailing comma,
// so the use site supplies the separator before any further entries.
// METH_O entries take exactly one positional argument; METH_NOARGS take none.
// NOTE(review): THPCppFunction_set_sequence_nr is referenced here but is not
// declared in this header. A declaration must be visible wherever this macro
// is expanded — confirm.
#define THP_FUNCTION_DEFAULT_METHODS                                               \
  {(char*)"_register_hook_dict", THPCppFunction_register_hook_dict, METH_O, nullptr}, \
  {(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr},            \
  {(char*)"register_prehook", THPCppFunction_register_prehook, METH_O, nullptr},      \
  {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr},                         \
  {(char*)"_sequence_nr", THPCppFunction_sequence_nr, METH_NOARGS, nullptr},          \
  {(char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr}
53
// PyGetSetDef entries (name, getter, setter, doc, closure) shared by Python
// wrappers of C++ autograd Nodes. Every entry supplies only a getter: the
// setter, doc and closure slots are all nullptr. The expansion has no trailing
// comma, so the use site supplies the separator before any further entries.
#define THP_FUNCTION_DEFAULT_PROPERTIES                                              \
  {(char*)"next_functions", THPCppFunction_next_functions, nullptr, nullptr, nullptr}, \
  {(char*)"requires_grad", THPCppFunction_requires_grad, nullptr, nullptr, nullptr},   \
  {(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr},             \
  {(char*)"_input_metadata", THPCppFunction_input_metadata, nullptr, nullptr, nullptr}
70
71 PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused);
72 PyObject* THPCppFunction_metadata(PyObject* self, void* _unused);
73 PyObject* THPCppFunction_requires_grad(PyObject* self, void* _unused);
74 PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
75 PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
76 PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
77
78 PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
79 PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs);
80 PyObject* THPCppFunction_input_metadata(PyObject* self, void* _unused);
81
82 PyTypeObject* _initFunctionPyTypeObject(
83 PyTypeObject& type,
84 const char* name,
85 PyGetSetDef* function_properties,
86 PyMethodDef* function_methods);
87
88 PyObject* registerFunctionHook(Node& fn, PyObject* hook);
89
90 PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
91
92 template <typename Ctor>
93 PyTypeObject* createForwardFunctionPyTypeObject(
94 PyTypeObject& type,
95 const char* name,
96 PyGetSetDef* function_properties = nullptr,
97 PyMethodDef* function_methods = nullptr) {
98 type.tp_new = &CppFunction_pynew<Ctor>;
99 return _initFunctionPyTypeObject(
100 type, name, function_properties, function_methods);
101 }
102
103 void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
104 PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);
105
106 bool THPCppFunction_Check(PyObject* obj);
107
108 } // namespace torch::autograd
109