1 #pragma once 2 3 #include <pybind11/pybind11.h> 4 #include <torch/csrc/utils/object_ptr.h> 5 #include <torch/csrc/utils/pybind.h> 6 #include <torch/csrc/utils/structseq.h> 7 8 namespace six { 9 10 // Usually instances of PyStructSequence is also an instance of tuple 11 // but in some py2 environment it is not, so we have to manually check 12 // the name of the type to determine if it is a namedtupled returned 13 // by a pytorch operator. 14 isStructSeq(pybind11::handle input)15inline bool isStructSeq(pybind11::handle input) { 16 return pybind11::cast<std::string>(input.get_type().attr("__module__")) == 17 "torch.return_types"; 18 } 19 isStructSeq(PyObject * obj)20inline bool isStructSeq(PyObject* obj) { 21 return isStructSeq(pybind11::handle(obj)); 22 } 23 isTuple(pybind11::handle input)24inline bool isTuple(pybind11::handle input) { 25 if (PyTuple_Check(input.ptr())) { 26 return true; 27 } 28 return false; 29 } 30 isTuple(PyObject * obj)31inline bool isTuple(PyObject* obj) { 32 return isTuple(pybind11::handle(obj)); 33 } 34 35 // maybeAsTuple: if the input is a structseq, then convert it to a tuple 36 // 37 // On Python 3, structseq is a subtype of tuple, so these APIs could be used 38 // directly. But on Python 2, structseq is not a subtype of tuple, so we need to 39 // manually create a new tuple object from structseq. maybeAsTuple(PyStructSequence * obj)40inline THPObjectPtr maybeAsTuple(PyStructSequence* obj) { 41 Py_INCREF(obj); 42 return THPObjectPtr((PyObject*)obj); 43 } 44 maybeAsTuple(PyObject * obj)45inline THPObjectPtr maybeAsTuple(PyObject* obj) { 46 if (isStructSeq(obj)) 47 return maybeAsTuple((PyStructSequence*)obj); 48 Py_INCREF(obj); 49 return THPObjectPtr(obj); 50 } 51 52 } // namespace six 53