xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/pybind.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/pybind.h>
2 #include <torch/csrc/utils/python_arg_parser.h>
3 #include <torch/csrc/utils/python_symnode.h>
4 
5 namespace pybind11::detail {
6 
load(py::handle src,bool)7 bool type_caster<c10::SymInt>::load(py::handle src, bool) {
8   if (torch::is_symint(src)) {
9     auto node = src.attr("node");
10     if (py::isinstance<c10::SymNodeImpl>(node)) {
11       value = c10::SymInt(py::cast<c10::SymNode>(node));
12       return true;
13     }
14 
15     value = c10::SymInt(static_cast<c10::SymNode>(
16         c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
17     return true;
18   }
19 
20   auto raw_obj = src.ptr();
21 
22   if (THPVariable_Check(raw_obj)) {
23     auto& var = THPVariable_Unpack(raw_obj);
24     if (var.numel() == 1 &&
25         at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
26       auto scalar = var.item();
27       TORCH_INTERNAL_ASSERT(scalar.isIntegral(/*include bool*/ false));
28       value = scalar.toSymInt();
29       return true;
30     }
31   }
32 
33   if (THPUtils_checkIndex(raw_obj)) {
34     value = c10::SymInt{THPUtils_unpackIndex(raw_obj)};
35     return true;
36   }
37   return false;
38 }
39 
cast(const c10::SymInt & si,return_value_policy,handle)40 py::handle type_caster<c10::SymInt>::cast(
41     const c10::SymInt& si,
42     return_value_policy /* policy */,
43     handle /* parent */) {
44   if (si.is_symbolic()) {
45     auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(
46         si.toSymNodeImplUnowned());
47     if (py_node) {
48       // Return the Python directly (unwrap)
49       return torch::get_symint_class()(py_node->getPyObj()).release();
50     } else {
51       // Wrap the C++ into Python
52       auto inner = py::cast(si.toSymNode());
53       if (!inner) {
54         throw python_error();
55       }
56       return torch::get_symint_class()(inner).release();
57     }
58   } else {
59     auto m = si.maybe_as_int();
60     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
61     return py::cast(*m).release();
62   }
63 }
64 
load(py::handle src,bool)65 bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
66   if (torch::is_symfloat(src)) {
67     value = c10::SymFloat(static_cast<c10::SymNode>(
68         c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
69     return true;
70   }
71 
72   auto raw_obj = src.ptr();
73   if (THPUtils_checkDouble(raw_obj)) {
74     value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
75     return true;
76   }
77   return false;
78 }
79 
cast(const c10::SymFloat & si,return_value_policy,handle)80 py::handle type_caster<c10::SymFloat>::cast(
81     const c10::SymFloat& si,
82     return_value_policy /* policy */,
83     handle /* parent */) {
84   if (si.is_symbolic()) {
85     // TODO: generalize this to work with C++ backed class
86     auto* py_node =
87         dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
88     TORCH_INTERNAL_ASSERT(py_node);
89     return torch::get_symfloat_class()(py_node->getPyObj()).release();
90   } else {
91     return py::cast(si.as_float_unchecked()).release();
92   }
93 }
94 
load(py::handle src,bool)95 bool type_caster<c10::SymBool>::load(py::handle src, bool) {
96   if (torch::is_symbool(src)) {
97     value = c10::SymBool(static_cast<c10::SymNode>(
98         c10::make_intrusive<torch::impl::PythonSymNodeImpl>(src.attr("node"))));
99     return true;
100   }
101 
102   auto raw_obj = src.ptr();
103   if (THPUtils_checkBool(raw_obj)) {
104     value = c10::SymBool{THPUtils_unpackBool(raw_obj)};
105     return true;
106   }
107   return false;
108 }
109 
cast(const c10::SymBool & si,return_value_policy,handle)110 py::handle type_caster<c10::SymBool>::cast(
111     const c10::SymBool& si,
112     return_value_policy /* policy */,
113     handle /* parent */) {
114   if (auto m = si.maybe_as_bool()) {
115     return py::cast(*m).release();
116   } else {
117     // TODO: generalize this to work with C++ backed class
118     auto* py_node =
119         dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
120     TORCH_INTERNAL_ASSERT(py_node);
121     return torch::get_symbool_class()(py_node->getPyObj()).release();
122   }
123 }
124 
load(py::handle src,bool)125 bool type_caster<c10::Scalar>::load(py::handle src, bool) {
126   TORCH_INTERNAL_ASSERT(
127       0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
128 }
129 
cast(const c10::Scalar & scalar,return_value_policy,handle)130 py::handle type_caster<c10::Scalar>::cast(
131     const c10::Scalar& scalar,
132     return_value_policy /* policy */,
133     handle /* parent */) {
134   if (scalar.isIntegral(/*includeBool*/ false)) {
135     // We have to be careful here; we cannot unconditionally route through
136     // SymInt because integer data from Tensors can easily be MIN_INT or
137     // very negative, which conflicts with the allocated range.
138     if (scalar.isSymbolic()) {
139       return py::cast(scalar.toSymInt()).release();
140     } else {
141       if (scalar.type() == at::ScalarType::UInt64) {
142         return py::cast(scalar.toUInt64()).release();
143       } else {
144         return py::cast(scalar.toLong()).release();
145       }
146     }
147   } else if (scalar.isFloatingPoint()) {
148     // This isn't strictly necessary but we add it for symmetry
149     if (scalar.isSymbolic()) {
150       return py::cast(scalar.toSymFloat()).release();
151     } else {
152       return py::cast(scalar.toDouble()).release();
153     }
154   } else if (scalar.isBoolean()) {
155     if (scalar.isSymbolic()) {
156       return py::cast(scalar.toSymBool()).release();
157     }
158     return py::cast(scalar.toBool()).release();
159   } else if (scalar.isComplex()) {
160     return py::cast(scalar.toComplexDouble()).release();
161   } else {
162     TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
163   }
164 }
165 
166 } // namespace pybind11::detail
167