xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/six.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <pybind11/pybind11.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <torch/csrc/utils/structseq.h>
7 
8 namespace six {
9 
10 // Usually instances of PyStructSequence is also an instance of tuple
11 // but in some py2 environment it is not, so we have to manually check
12 // the name of the type to determine if it is a namedtupled returned
13 // by a pytorch operator.
14 
isStructSeq(pybind11::handle input)15 inline bool isStructSeq(pybind11::handle input) {
16   return pybind11::cast<std::string>(input.get_type().attr("__module__")) ==
17       "torch.return_types";
18 }
19 
isStructSeq(PyObject * obj)20 inline bool isStructSeq(PyObject* obj) {
21   return isStructSeq(pybind11::handle(obj));
22 }
23 
isTuple(pybind11::handle input)24 inline bool isTuple(pybind11::handle input) {
25   if (PyTuple_Check(input.ptr())) {
26     return true;
27   }
28   return false;
29 }
30 
isTuple(PyObject * obj)31 inline bool isTuple(PyObject* obj) {
32   return isTuple(pybind11::handle(obj));
33 }
34 
35 // maybeAsTuple: if the input is a structseq, then convert it to a tuple
36 //
37 // On Python 3, structseq is a subtype of tuple, so these APIs could be used
38 // directly. But on Python 2, structseq is not a subtype of tuple, so we need to
39 // manually create a new tuple object from structseq.
maybeAsTuple(PyStructSequence * obj)40 inline THPObjectPtr maybeAsTuple(PyStructSequence* obj) {
41   Py_INCREF(obj);
42   return THPObjectPtr((PyObject*)obj);
43 }
44 
maybeAsTuple(PyObject * obj)45 inline THPObjectPtr maybeAsTuple(PyObject* obj) {
46   if (isStructSeq(obj))
47     return maybeAsTuple((PyStructSequence*)obj);
48   Py_INCREF(obj);
49   return THPObjectPtr(obj);
50 }
51 
52 } // namespace six
53