1 #include <torch/csrc/utils/pybind.h>
2 #include <torch/csrc/utils/python_arg_parser.h>
3 #include <torch/csrc/utils/python_symnode.h>
4
5 namespace pybind11::detail {
6
load(py::handle src,bool)7 bool type_caster<c10::SymInt>::load(py::handle src, bool) {
8 if (torch::is_symint(src)) {
9 auto node = src.attr("node");
10 if (py::isinstance<c10::SymNodeImpl>(node)) {
11 value = c10::SymInt(py::cast<c10::SymNode>(node));
12 return true;
13 }
14
15 value = c10::SymInt(static_cast<c10::SymNode>(
16 c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
17 return true;
18 }
19
20 auto raw_obj = src.ptr();
21
22 if (THPVariable_Check(raw_obj)) {
23 auto& var = THPVariable_Unpack(raw_obj);
24 if (var.numel() == 1 &&
25 at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
26 auto scalar = var.item();
27 TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ false));
28 value = scalar.toSymInt();
29 return true;
30 }
31 }
32
33 if (THPUtils_checkIndex(raw_obj)) {
34 value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
35 return true;
36 }
37 return false;
38 }
39
cast(const c10::SymInt & si,return_value_policy,handle)40 py::handle type_caster<c10::SymInt>::cast(
41 const c10::SymInt& si,
42 return_value_policy /* policy */,
43 handle /* parent */) {
44 if (si.is_symbolic()) {
45 auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
46 si.toSymNodeImplUnowned());
47 if (py_node) {
48 // Return the Python directly (unwrap)
49 return torch::get_symint_class()(py_node->getPyObj()).release();
50 } else {
51 // Wrap the C++ into Python
52 auto inner = py::cast(si.toSymNode());
53 if (!inner) {
54 throw python_error();
55 }
56 return torch::get_symint_class()(inner).release();
57 }
58 } else {
59 auto m = si.maybe_as_int();
60 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
61 return py::cast(*m).release();
62 }
63 }
64
load(py::handle src,bool)65 bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
66 if (torch::is_symfloat(src)) {
67 value = c10::SymFloat(static_cast<c10::SymNode>(
68 c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
69 return true;
70 }
71
72 auto raw_obj = src.ptr();
73 if (THPUtils_checkDouble(raw_obj)) {
74 value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
75 return true;
76 }
77 return false;
78 }
79
cast(const c10::SymFloat & si,return_value_policy,handle)80 py::handle type_caster<c10::SymFloat>::cast(
81 const c10::SymFloat& si,
82 return_value_policy /* policy */,
83 handle /* parent */) {
84 if (si.is_symbolic()) {
85 // TODO: generalize this to work with C++ backed class
86 auto* py_node =
87 dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
88 TORCH_INTERNAL_ASSERT(py_node);
89 return torch::get_symfloat_class()(py_node->getPyObj()).release();
90 } else {
91 return py::cast(si.as_float_unchecked()).release();
92 }
93 }
94
load(py::handle src,bool)95 bool type_caster<c10::SymBool>::load(py::handle src, bool) {
96 if (torch::is_symbool(src)) {
97 value = c10::SymBool(static_cast<c10::SymNode>(
98 c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
99 return true;
100 }
101
102 auto raw_obj = src.ptr();
103 if (THPUtils_checkBool(raw_obj)) {
104 value = c10::SymBool{THPUtils_unpackBool(raw_obj)};
105 return true;
106 }
107 return false;
108 }
109
cast(const c10::SymBool & si,return_value_policy,handle)110 py::handle type_caster<c10::SymBool>::cast(
111 const c10::SymBool& si,
112 return_value_policy /* policy */,
113 handle /* parent */) {
114 if (auto m = si.maybe_as_bool()) {
115 return py::cast(*m).release();
116 } else {
117 // TODO: generalize this to work with C++ backed class
118 auto* py_node =
119 dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
120 TORCH_INTERNAL_ASSERT(py_node);
121 return torch::get_symbool_class()(py_node->getPyObj()).release();
122 }
123 }
124
load(py::handle src,bool)125 bool type_caster<c10::Scalar>::load(py::handle src, bool) {
126 TORCH_INTERNAL_ASSERT(
127 0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
128 }
129
cast(const c10::Scalar & scalar,return_value_policy,handle)130 py::handle type_caster<c10::Scalar>::cast(
131 const c10::Scalar& scalar,
132 return_value_policy /* policy */,
133 handle /* parent */) {
134 if (scalar.isIntegral(/*includeBool*/ false)) {
135 // We have to be careful here; we cannot unconditionally route through
136 // SymInt because integer data from Tensors can easily be MIN_INT or
137 // very negative, which conflicts with the allocated range.
138 if (scalar.isSymbolic()) {
139 return py::cast(scalar.toSymInt()).release();
140 } else {
141 if (scalar.type() == at::ScalarType::UInt64) {
142 return py::cast(scalar.toUInt64()).release();
143 } else {
144 return py::cast(scalar.toLong()).release();
145 }
146 }
147 } else if (scalar.isFloatingPoint()) {
148 // This isn't strictly necessary but we add it for symmetry
149 if (scalar.isSymbolic()) {
150 return py::cast(scalar.toSymFloat()).release();
151 } else {
152 return py::cast(scalar.toDouble()).release();
153 }
154 } else if (scalar.isBoolean()) {
155 if (scalar.isSymbolic()) {
156 return py::cast(scalar.toSymBool()).release();
157 }
158 return py::cast(scalar.toBool()).release();
159 } else if (scalar.isComplex()) {
160 return py::cast(scalar.toComplexDouble()).release();
161 } else {
162 TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
163 }
164 }
165
166 } // namespace pybind11::detail
167