xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_cpp_function.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 #include <memory>
5 #include <typeinfo>
6 
7 #include <torch/csrc/Exceptions.h>
8 #include <torch/csrc/autograd/function.h>
9 #include <torch/csrc/utils/object_ptr.h>
10 
11 namespace torch::autograd {
12 
13 struct THPCppFunction {
14   PyObject_HEAD std::shared_ptr<Node> cdata;
15 };
16 
17 template <typename Ctor>
CppFunction_pynew(PyTypeObject * type,PyObject * args,PyObject * kwds)18 PyObject* CppFunction_pynew(
19     PyTypeObject* type,
20     PyObject* args,
21     PyObject* kwds) {
22   THPObjectPtr obj(type->tp_alloc(type, 0));
23   if (!obj)
24     return nullptr;
25   THPCppFunction* f = (THPCppFunction*)obj.get();
26   HANDLE_TH_ERRORS
27   new (&f->cdata) std::shared_ptr<Node>(Ctor()(args));
28   END_HANDLE_TH_ERRORS
29   if (!f->cdata) {
30     return nullptr;
31   }
32   return obj.release();
33 }
34 
// Default PyMethodDef entries (name, function, flags, doc) shared by Python
// types that wrap C++ autograd nodes. Expands to a comma-separated list with
// no trailing comma, meant to be placed inside a PyMethodDef array
// initializer. Every THPCppFunction_* function it names must be declared at
// the point of expansion.
#define THP_FUNCTION_DEFAULT_METHODS                                                   \
  {(char*)"_register_hook_dict", THPCppFunction_register_hook_dict, METH_O, nullptr}, \
  {(char*)"register_hook", THPCppFunction_register_hook, METH_O, nullptr},             \
  {(char*)"register_prehook", THPCppFunction_register_prehook, METH_O, nullptr},       \
  {(char*)"name", THPCppFunction_name, METH_NOARGS, nullptr},                          \
  {(char*)"_sequence_nr", THPCppFunction_sequence_nr, METH_NOARGS, nullptr},           \
  {(char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr}
53 
// Default PyGetSetDef entries (name, getter, setter, doc, closure) shared by
// Python types that wrap C++ autograd nodes. Every entry has a getter only (no
// setter, no doc, no closure). Expands to a comma-separated list with no
// trailing comma, meant to be placed inside a PyGetSetDef array initializer.
#define THP_FUNCTION_DEFAULT_PROPERTIES                                                      \
  {(char*)"next_functions", THPCppFunction_next_functions, nullptr, nullptr, nullptr},     \
  {(char*)"requires_grad", THPCppFunction_requires_grad, nullptr, nullptr, nullptr},       \
  {(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr},                 \
  {(char*)"_input_metadata", THPCppFunction_input_metadata, nullptr, nullptr, nullptr}
70 
71 PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused);
72 PyObject* THPCppFunction_metadata(PyObject* self, void* _unused);
73 PyObject* THPCppFunction_requires_grad(PyObject* self, void* _unused);
74 PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var);
75 PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook);
76 PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
77 
78 PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
79 PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs);
80 PyObject* THPCppFunction_input_metadata(PyObject* self, void* _unused);
81 
82 PyTypeObject* _initFunctionPyTypeObject(
83     PyTypeObject& type,
84     const char* name,
85     PyGetSetDef* function_properties,
86     PyMethodDef* function_methods);
87 
88 PyObject* registerFunctionHook(Node& fn, PyObject* hook);
89 
90 PyObject* registerFunctionPreHook(Node& fn, PyObject* hook);
91 
92 template <typename Ctor>
93 PyTypeObject* createForwardFunctionPyTypeObject(
94     PyTypeObject& type,
95     const char* name,
96     PyGetSetDef* function_properties = nullptr,
97     PyMethodDef* function_methods = nullptr) {
98   type.tp_new = &CppFunction_pynew<Ctor>;
99   return _initFunctionPyTypeObject(
100       type, name, function_properties, function_methods);
101 }
102 
103 void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
104 PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);
105 
106 bool THPCppFunction_Check(PyObject* obj);
107 
108 } // namespace torch::autograd
109