1 #pragma once
2
3 #include <torch/csrc/python_headers.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <stdexcept>
7 #include <string>
8
9 // Utilities for handling Python strings. Note that PyString, when defined, is
10 // the same as PyBytes.
11
12 // Returns true if obj is a bytes/str or unicode object
13 // As of Python 3.6, this does not require the GIL
THPUtils_checkString(PyObject * obj)14 inline bool THPUtils_checkString(PyObject* obj) {
15 return PyBytes_Check(obj) || PyUnicode_Check(obj);
16 }
17
18 // Unpacks PyBytes (PyString) or PyUnicode as std::string
19 // PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8.
20 // NOTE: this method requires the GIL
THPUtils_unpackString(PyObject * obj)21 inline std::string THPUtils_unpackString(PyObject* obj) {
22 if (PyBytes_Check(obj)) {
23 size_t size = PyBytes_GET_SIZE(obj);
24 return std::string(PyBytes_AS_STRING(obj), size);
25 }
26 if (PyUnicode_Check(obj)) {
27 Py_ssize_t size = 0;
28 const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
29 if (!data) {
30 throw std::runtime_error("error unpacking string as utf-8");
31 }
32 return std::string(data, (size_t)size);
33 }
34 throw std::runtime_error("unpackString: expected bytes or unicode object");
35 }
36
37 // Unpacks PyBytes (PyString) or PyUnicode as c10::string_view
38 // PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8.
39 // NOTE: If `obj` is destroyed, then the non-owning c10::string_view will
40 // become invalid. If the string needs to be accessed at any point after
41 // `obj` is destroyed, then the c10::string_view should be copied into
42 // a std::string, or another owning object, and kept alive. For an example,
43 // look at how IValue and autograd nodes handle c10::string_view arguments.
44 // NOTE: this method requires the GIL
THPUtils_unpackStringView(PyObject * obj)45 inline c10::string_view THPUtils_unpackStringView(PyObject* obj) {
46 if (PyBytes_Check(obj)) {
47 size_t size = PyBytes_GET_SIZE(obj);
48 return c10::string_view(PyBytes_AS_STRING(obj), size);
49 }
50 if (PyUnicode_Check(obj)) {
51 Py_ssize_t size = 0;
52 const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
53 if (!data) {
54 throw std::runtime_error("error unpacking string as utf-8");
55 }
56 return c10::string_view(data, (size_t)size);
57 }
58 throw std::runtime_error("unpackString: expected bytes or unicode object");
59 }
60
THPUtils_packString(const char * str)61 inline PyObject* THPUtils_packString(const char* str) {
62 return PyUnicode_FromString(str);
63 }
64
THPUtils_packString(const std::string & str)65 inline PyObject* THPUtils_packString(const std::string& str) {
66 return PyUnicode_FromStringAndSize(str.c_str(), str.size());
67 }
68
THPUtils_internString(const std::string & str)69 inline PyObject* THPUtils_internString(const std::string& str) {
70 return PyUnicode_InternFromString(str.c_str());
71 }
72
73 // Precondition: THPUtils_checkString(obj) must be true
THPUtils_isInterned(PyObject * obj)74 inline bool THPUtils_isInterned(PyObject* obj) {
75 return PyUnicode_CHECK_INTERNED(obj);
76 }
77
78 // Precondition: THPUtils_checkString(obj) must be true
THPUtils_internStringInPlace(PyObject ** obj)79 inline void THPUtils_internStringInPlace(PyObject** obj) {
80 PyUnicode_InternInPlace(obj);
81 }
82
83 /*
84 * Reference:
85 * https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
86 *
87 * Stripped down version of PyObject_GetAttrString,
88 * avoids lookups for None, tuple, and List objects,
89 * and doesn't create a PyErr since this code ignores it.
90 *
91 * This can be much faster then PyObject_GetAttrString where
92 * exceptions are not used by caller.
93 *
94 * 'obj' is the object to search for attribute.
95 *
96 * 'name' is the attribute to search for.
97 *
98 * Returns a py::object wrapping the return value. If the attribute lookup
99 * failed the value will be NULL.
100 *
101 */
102
PyObject_FastGetAttrString(PyObject * obj,const char * name)103 inline py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) {
104 PyTypeObject* tp = Py_TYPE(obj);
105 PyObject* res = (PyObject*)nullptr;
106
107 /* Attribute referenced by (char *)name */
108 if (tp->tp_getattr != nullptr) {
109 // This is OK per https://bugs.python.org/issue39620
110 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
111 res = (*tp->tp_getattr)(obj, const_cast<char*>(name));
112 if (res == nullptr) {
113 PyErr_Clear();
114 }
115 }
116 /* Attribute referenced by (PyObject *)name */
117 else if (tp->tp_getattro != nullptr) {
118 auto w = py::reinterpret_steal<py::object>(THPUtils_internString(name));
119 if (w.ptr() == nullptr) {
120 return py::object();
121 }
122 res = (*tp->tp_getattro)(obj, w.ptr());
123 if (res == nullptr) {
124 PyErr_Clear();
125 }
126 }
127 return py::reinterpret_steal<py::object>(res);
128 }
129