xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/pybind.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 #include <torch/csrc/utils/pythoncapi_compat.h>
5 
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/jit_type_base.h>
8 #include <c10/util/irange.h>
9 #include <pybind11/pybind11.h>
10 #include <pybind11/stl.h>
11 
12 #include <torch/csrc/Device.h>
13 #include <torch/csrc/Dtype.h>
14 #include <torch/csrc/DynamicTypes.h>
15 #include <torch/csrc/Generator.h>
16 #include <torch/csrc/MemoryFormat.h>
17 #include <torch/csrc/Stream.h>
18 #include <torch/csrc/utils/tensor_memoryformats.h>
19 
// Short alias used throughout torch/csrc for the pybind11 namespace.
namespace py = pybind11;

// Register c10::intrusive_ptr as a custom pybind11 holder type, so classes can
// be bound with `py::class_<T, c10::intrusive_ptr<T>>`. See
// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
// The third argument `true` tells pybind11 (per the linked docs) that the
// holder can always be safely constructed from a raw T* pointer, which is the
// intrusive-reference-counting case.
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr<T>, true);

// Also register the c10 type-pointer wrappers as holder types.
// SingletonOrSharedTypePtr uses the default (false) third argument, while
// SingletonTypePtr passes `true`, allowing construction from a raw pointer.
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonOrSharedTypePtr<T>);
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonTypePtr<T>, true);
29 
30 namespace pybind11::detail {
31 
// torch.Tensor <-> at::Tensor conversions (without unwrapping)
//
// Declaration only: load() and cast() are defined out of line (not in this
// header). The struct is marked TORCH_PYTHON_API so those definitions can be
// shared across libraries.
template <>
struct TORCH_PYTHON_API type_caster<at::Tensor> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::Tensor, _("torch.Tensor"));

  // Python -> C++. Per pybind11's type_caster contract, fills `value` and
  // returns true on success, false if `src` cannot be converted. The second
  // (implicit-conversion) flag is left unnamed here.
  bool load(handle src, bool);

  // C++ -> Python. Returns a handle for `src`; policy and parent are unnamed
  // in this declaration.
  static handle cast(
      const at::Tensor& src,
      return_value_policy /* policy */,
      handle /* parent */);
};
46 
47 // torch._StorageBase <-> at::Storage
48 template <>
49 struct type_caster<at::Storage> {
50  public:
51   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
52   PYBIND11_TYPE_CASTER(at::Storage, _("torch.StorageBase"));
53 
54   bool load(handle src, bool) {
55     PyObject* obj = src.ptr();
56     if (torch::isStorage(obj)) {
57       value = torch::createStorage(obj);
58       return true;
59     }
60     return false;
61   }
62 
63   static handle cast(
64       const at::Storage& src,
65       return_value_policy /* policy */,
66       handle /* parent */) {
67     return handle(torch::createPyObject(src));
68   }
69 };
70 
71 template <>
72 struct type_caster<at::Generator> {
73  public:
74   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
75   PYBIND11_TYPE_CASTER(at::Generator, _("torch.Generator"));
76 
77   bool load(handle src, bool) {
78     PyObject* obj = src.ptr();
79     if (THPGenerator_Check(obj)) {
80       value = reinterpret_cast<THPGenerator*>(obj)->cdata;
81       return true;
82     }
83     return false;
84   }
85 
86   static handle cast(
87       const at::Generator& src,
88       return_value_policy /* policy */,
89       handle /* parent */) {
90     return handle(THPGenerator_Wrap(src));
91   }
92 };
93 
// Python <-> at::IntArrayRef, advertised to pybind11 as "Tuple[int, ...]".
// Declaration only: load() and cast() are defined out of line.
template <>
struct TORCH_PYTHON_API type_caster<at::IntArrayRef> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::IntArrayRef, _("Tuple[int, ...]"));

  bool load(handle src, bool);
  static handle cast(
      at::IntArrayRef src,
      return_value_policy /* policy */,
      handle /* parent */);

 private:
  // NOTE(review): at::IntArrayRef is a view type, so this vector presumably
  // owns the elements that `value` points at after a successful load() —
  // confirm against the out-of-line load() definition.
  std::vector<int64_t> v_value;
};
109 
// Python <-> at::SymIntArrayRef, advertised to pybind11 as "List[int]".
// Declaration only: load() and cast() are defined out of line.
template <>
struct TORCH_PYTHON_API type_caster<at::SymIntArrayRef> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::SymIntArrayRef, _("List[int]"));

  bool load(handle src, bool);
  static handle cast(
      at::SymIntArrayRef src,
      return_value_policy /* policy */,
      handle /* parent */);

 private:
  // NOTE(review): presumably backs the elements `value` refers to after a
  // successful load(), since the ArrayRef itself is a view — confirm against
  // the out-of-line load() definition.
  std::vector<c10::SymInt> v_value;
};
125 
// Python <-> at::ArrayRef<c10::SymNode>, advertised to pybind11 as
// "List[SymNode]". Declaration only: load() and cast() are defined out of line.
template <>
struct TORCH_PYTHON_API type_caster<at::ArrayRef<c10::SymNode>> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::ArrayRef<c10::SymNode>, _("List[SymNode]"));

  bool load(handle src, bool);
  static handle cast(
      at::ArrayRef<c10::SymNode> src,
      return_value_policy /* policy */,
      handle /* parent */);

 private:
  // NOTE(review): presumably owns the SymNodes that `value` views after a
  // successful load() — confirm against the out-of-line load() definition.
  std::vector<c10::SymNode> v_value;
};
141 
142 template <>
143 struct type_caster<at::MemoryFormat> {
144  public:
145   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
146   PYBIND11_TYPE_CASTER(at::MemoryFormat, _("torch.memory_format"));
147 
148   bool load(handle src, bool) {
149     PyObject* obj = src.ptr();
150     if (THPMemoryFormat_Check(obj)) {
151       value = reinterpret_cast<THPMemoryFormat*>(obj)->memory_format;
152       return true;
153     }
154     return false;
155   }
156   static handle cast(
157       at::MemoryFormat src,
158       return_value_policy /* policy */,
159       handle /* parent */) {
160     return handle(Py_NewRef(torch::utils::getTHPMemoryFormat(src)));
161   }
162 };
163 
164 template <>
165 struct type_caster<at::Device> {
166  public:
167   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
168   PYBIND11_TYPE_CASTER(at::Device, _("torch.device"));
169 
170   // PYBIND11_TYPE_CASTER defines a member field called value. Since at::Device
171   // cannot be default-initialized, we provide this constructor to explicitly
172   // initialize that field. The value doesn't matter as it will be overwritten
173   // after a successful call to load.
174   type_caster() : value(c10::kCPU) {}
175 
176   bool load(handle src, bool) {
177     PyObject* obj = src.ptr();
178     if (THPDevice_Check(obj)) {
179       value = reinterpret_cast<THPDevice*>(obj)->device;
180       return true;
181     }
182     return false;
183   }
184 
185   static handle cast(
186       const at::Device& src,
187       return_value_policy /* policy */,
188       handle /* parent */) {
189     return handle(THPDevice_New(src));
190   }
191 };
192 
193 template <>
194 struct type_caster<at::ScalarType> {
195  public:
196   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
197   PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
198 
199   // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
200   // cannot be default-initialized, we provide this constructor to explicitly
201   // initialize that field. The value doesn't matter as it will be overwritten
202   // after a successful call to load.
203   type_caster() : value(at::kFloat) {}
204 
205   bool load(handle src, bool) {
206     PyObject* obj = src.ptr();
207     if (THPDtype_Check(obj)) {
208       value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
209       return true;
210     }
211     return false;
212   }
213 
214   static handle cast(
215       const at::ScalarType& src,
216       return_value_policy /* policy */,
217       handle /* parent */) {
218     return Py_NewRef(torch::getTHPDtype(src));
219   }
220 };
221 
222 template <>
223 struct type_caster<c10::Stream> {
224  public:
225   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
226   PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream"));
227 
228   // PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream
229   // cannot be default-initialized, we provide this constructor to explicitly
230   // initialize that field. The value doesn't matter as it will be overwritten
231   // after a successful call to load.
232   type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {}
233 
234   bool load(handle src, bool) {
235     PyObject* obj = src.ptr();
236     if (THPStream_Check(obj)) {
237       value = c10::Stream::unpack3(
238           ((THPStream*)obj)->stream_id,
239           static_cast<c10::DeviceIndex>(((THPStream*)obj)->device_index),
240           static_cast<c10::DeviceType>(((THPStream*)obj)->device_type));
241       return true;
242     }
243     return false;
244   }
245 
246   static handle cast(
247       const c10::Stream& src,
248       return_value_policy /* policy */,
249       handle /* parent */) {
250     return handle(THPStream_Wrap(src));
251   }
252 };
253 
254 template <>
255 struct type_caster<c10::DispatchKey>
256     : public type_caster_base<c10::DispatchKey> {
257   using base = type_caster_base<c10::DispatchKey>;
258   c10::DispatchKey tmp{};
259 
260  public:
261   bool load(handle src, bool convert) {
262     if (base::load(src, convert)) {
263       return true;
264     } else if (py::isinstance(
265                    src, py::module_::import("builtins").attr("str"))) {
266       tmp = c10::parseDispatchKey(py::cast<std::string>(src));
267       value = &tmp;
268       return true;
269     }
270     return false;
271   }
272 
273   static handle cast(
274       c10::DispatchKey src,
275       return_value_policy policy,
276       handle parent) {
277     return base::cast(src, policy, parent);
278   }
279 };
280 
// Python <-> c10::Scalar. The pybind11 signature name advertises
// "Union[Number, torch.SymInt, torch.SymFloat, torch.SymBool]".
// Declaration only: load() and cast() are defined out of line.
template <>
struct TORCH_PYTHON_API type_caster<c10::Scalar> {
 public:
  PYBIND11_TYPE_CASTER(
      c10::Scalar,
      _("Union[Number, torch.SymInt, torch.SymFloat, torch.SymBool]"));
  // Python -> C++; fills `value` on success per pybind11's caster contract.
  bool load(py::handle src, bool);

  // C++ -> Python; policy and parent are unnamed in this declaration.
  static py::handle cast(
      const c10::Scalar& si,
      return_value_policy /* policy */,
      handle /* parent */);
};
294 
// Python <-> c10::SymInt, advertised as "Union[int, torch.SymInt]".
// Declaration only: load() and cast() are defined out of line.
template <>
struct TORCH_PYTHON_API type_caster<c10::SymInt> {
 public:
  PYBIND11_TYPE_CASTER(c10::SymInt, _("Union[int, torch.SymInt]"));
  // Python -> C++; fills `value` on success per pybind11's caster contract.
  bool load(py::handle src, bool);

  // C++ -> Python; policy and parent are unnamed in this declaration.
  static py::handle cast(
      const c10::SymInt& si,
      return_value_policy /* policy */,
      handle /* parent */);
};
306 
// Python <-> c10::SymFloat. Declaration only: load() and cast() are defined
// out of line.
// NOTE(review): the signature name here is plain "float", whereas the SymInt
// and SymBool casters advertise "Union[int, torch.SymInt]" /
// "Union[bool, torch.SymBool]". If load() also accepts torch.SymFloat,
// "Union[float, torch.SymFloat]" would be the consistent name — confirm
// against the out-of-line definition before changing it.
template <>
struct TORCH_PYTHON_API type_caster<c10::SymFloat> {
 public:
  PYBIND11_TYPE_CASTER(c10::SymFloat, _("float"));
  // Python -> C++; fills `value` on success per pybind11's caster contract.
  bool load(py::handle src, bool);

  // C++ -> Python; policy and parent are unnamed in this declaration.
  static py::handle cast(
      const c10::SymFloat& si,
      return_value_policy /* policy */,
      handle /* parent */);
};
318 
// Python <-> c10::SymBool, advertised as "Union[bool, torch.SymBool]".
// Declaration only: load() and cast() are defined out of line.
template <>
struct TORCH_PYTHON_API type_caster<c10::SymBool> {
 public:
  PYBIND11_TYPE_CASTER(c10::SymBool, _("Union[bool, torch.SymBool]"));
  // Python -> C++; fills `value` on success per pybind11's caster contract.
  bool load(py::handle src, bool);

  // C++ -> Python; policy and parent are unnamed in this declaration.
  static py::handle cast(
      const c10::SymBool& si,
      return_value_policy /* policy */,
      handle /* parent */);
};
330 
331 template <typename T>
332 struct type_caster<c10::complex<T>> {
333  public:
334   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
335   PYBIND11_TYPE_CASTER(c10::complex<T>, _("complex"));
336 
337   bool load(handle src, bool) {
338     PyObject* obj = src.ptr();
339 
340     // Refered from `THPUtils_unpackComplexDouble`
341     Py_complex py_complex = PyComplex_AsCComplex(obj);
342     if (py_complex.real == -1.0 && PyErr_Occurred()) {
343       return false;
344     }
345 
346     // Python's Complex is always double precision.
347     value = c10::complex<double>(py_complex.real, py_complex.imag);
348     return true;
349   }
350 
351   static handle cast(
352       const c10::complex<T>& complex,
353       return_value_policy /* policy */,
354       handle /* parent */) {
355     // Python only knows double precision complex.
356     return handle(PyComplex_FromDoubles(complex.real(), complex.imag()));
357   }
358 };
359 
360 } // namespace pybind11::detail
361 
362 namespace torch::impl {
363 
364 // Use this function if you have a C++ object that is used from both C++
365 // and Python contexts, and you need its GIL to be released when you
366 // destruct it in the Python context.
367 //
368 // This function is a valid shared_ptr destructor and can be used to
369 // conveniently allocate a shared_ptr to an object whose destructor will be run
370 // without the GIL.  Pass it as the second argument to shared_ptr, e.g.,
371 //
372 //    shared_ptr<T>(new T(), destroy_without_gil<T>)
373 //
374 // Attaching the GIL release logic to the holder pointer rather than the
375 // actual destructor of T is helpful when T is Python-agnostic and
376 // shouldn't refer to the PYthon API.
377 //
378 // Note there are limitations to the correctness of code that makes use of this.
379 // In particular, if a shared_ptr is constructed from C++ code without this
380 // destructor and then passed to pybind11, pybind11 will happily take ownership
381 // of the shared_ptr (and be willing to destruct it from a context where it is
382 // holding the GIL).  unique_ptr with a type branded deleter is less prone to
383 // this problem, because a stock deleter unique_ptr is not convertible with it.
384 // I plan to mitigate this problem by adding DEBUG-only asserts to the true C++
385 // destructors that the GIL is not held (using a virtual call to get to the
386 // Python interpreter); alternately, we could use a virtual call to simply
387 // ensure we release the GIL in the C++ destructor, however, this is a layering
388 // violation (why does code that is ostensibly Python agnostic calling into the
389 // GIL).
390 //
391 // Adapted from
392 // https://github.com/pybind/pybind11/issues/1446#issuecomment-406341510
393 template <typename T>
394 inline void destroy_without_gil(T* ptr) {
395   // Because the ownership of a shared_ptr is diffuse, it's not possible to
396   // necessarily predict whether or not the last reference to an object will
397   // be destructed from Python or C++.  This means that in the destructor here,
398   // we don't necessarily know if we actually have the GIL or not; in fact,
399   // we don't even know if the Python interpreter still exists!  Thus, we have
400   // to test for it before releasing the GIL.
401   //
402   // PyGILState_Check is hopefully self explanatory.  But Py_IsInitialized or
403   // _PyIsFinalizing?  Both get set at the same time during the Python
404   // destruction process:
405   // https://github.com/python/cpython/blob/d92513390a1a0da781bb08c284136f4d7abea36d/Python/pylifecycle.c#L1716-L1717
406   // so the operant question is whether or not you want to release the GIL after
407   // finalization has completed (and there is just no Python interpreter).
408   // Clearly there is no need to release GIL in that state, so we want
409   // Py_IsInitialized.
410   if (Py_IsInitialized() && PyGILState_Check()) {
411     pybind11::gil_scoped_release nogil;
412     delete ptr;
413   } else {
414     delete ptr;
415   }
416 }
417 
418 } // namespace torch::impl
419