xref: /aosp_15_r20/external/pytorch/tools/autograd/templates/python_nested_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 // ${generated_comment}
3 
4 #include "torch/csrc/Device.h"
5 #include "torch/csrc/DynamicTypes.h"
6 #include "torch/csrc/Exceptions.h"
7 #include "torch/csrc/autograd/python_nested_functions.h"
8 #include "torch/csrc/autograd/generated/python_return_types.h"
9 #include "torch/csrc/autograd/python_variable.h"
10 #include "torch/csrc/autograd/utils/wrap_outputs.h"
11 #include "torch/csrc/autograd/utils/python_arg_parsing.h"
12 #include "torch/csrc/autograd/generated/variable_factories.h"
13 #include "torch/csrc/utils/out_types.h"
14 #include "torch/csrc/utils/pycfunction_helpers.h"
15 #include "torch/csrc/utils/python_arg_parser.h"
16 #include "torch/csrc/utils/structseq.h"
17 #include "torch/csrc/utils/device_lazy_init.h"
18 
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #else
22 $ops_headers
23 #endif
24 
25 using at::Tensor;
26 using at::Device;
27 using at::Layout;
28 using at::Scalar;
29 using at::ScalarType;
30 using at::Backend;
31 using at::OptionalDeviceGuard;
32 using at::DeviceGuard;
33 using at::TensorOptions;
34 using at::IntArrayRef;
35 using at::OptionalIntArrayRef;
36 using at::Generator;
37 using at::TensorList;
38 using at::Dimname;
39 using at::DimnameList;
40 
41 using namespace torch::autograd::utils;
42 
43 namespace torch::autograd {
44 
45 // generated forward declarations start here
46 
47 ${py_forwards}
48 
49 static PyMethodDef nested_functions[] = {
50   {NULL, NULL, 0, NULL},
51   ${py_method_defs}
52   {NULL}
53 };
54 
55 static PyObject* THPNestedVariableFunctionsModule = NULL;
56 
initNestedFunctions(PyObject * module)57 void initNestedFunctions(PyObject* module) {
58   nested_functions[0] = get_nested_functions_manual()[0];
59   static struct PyModuleDef def = {
60      PyModuleDef_HEAD_INIT,
61      "torch._C._nested",
62      NULL,
63      -1,
64      nested_functions
65   };
66   PyObject* nested = PyModule_Create(&def);
67   THPNestedVariableFunctionsModule = nested;
68   if (!nested) {
69     throw python_error();
70   }
71   // steals a reference to nested
72   if (PyModule_AddObject(module, "_nested", nested) != 0) {
73     throw python_error();
74   }
75 }
76 
77 // generated methods start here
78 
79 ${py_methods}
80 
81 } // namespace torch::autograd
82