xref: /aosp_15_r20/external/pytorch/tools/autograd/templates/python_special_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 // ${generated_comment}
3 
4 #include "torch/csrc/Device.h"
5 #include "torch/csrc/DynamicTypes.h"
6 #include "torch/csrc/Exceptions.h"
7 #include "torch/csrc/autograd/python_special_functions.h"
8 #include "torch/csrc/autograd/generated/python_return_types.h"
9 #include "torch/csrc/autograd/python_variable.h"
10 #include "torch/csrc/autograd/utils/wrap_outputs.h"
11 #include "torch/csrc/autograd/utils/python_arg_parsing.h"
12 #include "torch/csrc/autograd/generated/variable_factories.h"
13 #include "torch/csrc/utils/out_types.h"
14 #include "torch/csrc/utils/pycfunction_helpers.h"
15 #include "torch/csrc/utils/python_arg_parser.h"
16 #include "torch/csrc/utils/structseq.h"
17 #include "torch/csrc/utils/device_lazy_init.h"
18 
19 #ifndef AT_PER_OPERATOR_HEADERS
20 #include <ATen/Functions.h>
21 #else
22 $ops_headers
23 #endif
24 
25 using at::Tensor;
26 using at::Device;
27 using at::Layout;
28 using at::Scalar;
29 using at::ScalarType;
30 using at::Backend;
31 using at::OptionalDeviceGuard;
32 using at::DeviceGuard;
33 using at::TensorOptions;
34 using at::IntArrayRef;
35 using at::Generator;
36 using at::TensorList;
37 using at::Dimname;
38 using at::DimnameList;
39 
40 using torch::utils::check_out_type_matches;
41 using namespace torch::autograd::utils;
42 
43 namespace torch::autograd {
44 
45 // generated forward declarations start here
46 
47 ${py_forwards}
48 
49 static PyMethodDef special_functions[] = {
50   ${py_method_defs}
51   {NULL}
52 };
53 
54 static PyObject* THPSpecialVariableFunctionsModule = NULL;
55 
initSpecialFunctions(PyObject * module)56 void initSpecialFunctions(PyObject* module) {
57   static struct PyModuleDef def = {
58      PyModuleDef_HEAD_INIT,
59      "torch._C._special",
60      NULL,
61      -1,
62      special_functions
63   };
64   PyObject* special = PyModule_Create(&def);
65   THPSpecialVariableFunctionsModule = special;
66   if (!special) {
67     throw python_error();
68   }
69   // steals a reference to special
70   if (PyModule_AddObject(module, "_special", special) != 0) {
71     throw python_error();
72   }
73 }
74 
75 // generated methods start here
76 
77 ${py_methods}
78 
79 } // namespace torch::autograd
80