xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/python_strings.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/pybind.h>
6 #include <stdexcept>
7 #include <string>
8 
9 // Utilities for handling Python strings. Note that PyString, when defined, is
10 // the same as PyBytes.
11 
12 // Returns true if obj is a bytes/str or unicode object
13 // As of Python 3.6, this does not require the GIL
THPUtils_checkString(PyObject * obj)14 inline bool THPUtils_checkString(PyObject* obj) {
15   return PyBytes_Check(obj) || PyUnicode_Check(obj);
16 }
17 
18 // Unpacks PyBytes (PyString) or PyUnicode as std::string
19 // PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8.
20 // NOTE: this method requires the GIL
THPUtils_unpackString(PyObject * obj)21 inline std::string THPUtils_unpackString(PyObject* obj) {
22   if (PyBytes_Check(obj)) {
23     size_t size = PyBytes_GET_SIZE(obj);
24     return std::string(PyBytes_AS_STRING(obj), size);
25   }
26   if (PyUnicode_Check(obj)) {
27     Py_ssize_t size = 0;
28     const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
29     if (!data) {
30       throw std::runtime_error("error unpacking string as utf-8");
31     }
32     return std::string(data, (size_t)size);
33   }
34   throw std::runtime_error("unpackString: expected bytes or unicode object");
35 }
36 
37 // Unpacks PyBytes (PyString) or PyUnicode as c10::string_view
38 // PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8.
39 // NOTE: If `obj` is destroyed, then the non-owning c10::string_view will
40 //   become invalid. If the string needs to be accessed at any point after
41 //   `obj` is destroyed, then the c10::string_view should be copied into
42 //   a std::string, or another owning object, and kept alive. For an example,
43 //   look at how IValue and autograd nodes handle c10::string_view arguments.
44 // NOTE: this method requires the GIL
THPUtils_unpackStringView(PyObject * obj)45 inline c10::string_view THPUtils_unpackStringView(PyObject* obj) {
46   if (PyBytes_Check(obj)) {
47     size_t size = PyBytes_GET_SIZE(obj);
48     return c10::string_view(PyBytes_AS_STRING(obj), size);
49   }
50   if (PyUnicode_Check(obj)) {
51     Py_ssize_t size = 0;
52     const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
53     if (!data) {
54       throw std::runtime_error("error unpacking string as utf-8");
55     }
56     return c10::string_view(data, (size_t)size);
57   }
58   throw std::runtime_error("unpackString: expected bytes or unicode object");
59 }
60 
THPUtils_packString(const char * str)61 inline PyObject* THPUtils_packString(const char* str) {
62   return PyUnicode_FromString(str);
63 }
64 
THPUtils_packString(const std::string & str)65 inline PyObject* THPUtils_packString(const std::string& str) {
66   return PyUnicode_FromStringAndSize(str.c_str(), str.size());
67 }
68 
THPUtils_internString(const std::string & str)69 inline PyObject* THPUtils_internString(const std::string& str) {
70   return PyUnicode_InternFromString(str.c_str());
71 }
72 
73 // Precondition: THPUtils_checkString(obj) must be true
THPUtils_isInterned(PyObject * obj)74 inline bool THPUtils_isInterned(PyObject* obj) {
75   return PyUnicode_CHECK_INTERNED(obj);
76 }
77 
78 // Precondition: THPUtils_checkString(obj) must be true
THPUtils_internStringInPlace(PyObject ** obj)79 inline void THPUtils_internStringInPlace(PyObject** obj) {
80   PyUnicode_InternInPlace(obj);
81 }
82 
83 /*
84  * Reference:
85  * https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42
86  *
87  * Stripped down version of PyObject_GetAttrString,
88  * avoids lookups for None, tuple, and List objects,
89  * and doesn't create a PyErr since this code ignores it.
90  *
91  * This can be much faster then PyObject_GetAttrString where
92  * exceptions are not used by caller.
93  *
94  * 'obj' is the object to search for attribute.
95  *
96  * 'name' is the attribute to search for.
97  *
98  * Returns a py::object wrapping the return value. If the attribute lookup
99  * failed the value will be NULL.
100  *
101  */
102 
PyObject_FastGetAttrString(PyObject * obj,const char * name)103 inline py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) {
104   PyTypeObject* tp = Py_TYPE(obj);
105   PyObject* res = (PyObject*)nullptr;
106 
107   /* Attribute referenced by (char *)name */
108   if (tp->tp_getattr != nullptr) {
109     // This is OK per https://bugs.python.org/issue39620
110     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
111     res = (*tp->tp_getattr)(obj, const_cast<char*>(name));
112     if (res == nullptr) {
113       PyErr_Clear();
114     }
115   }
116   /* Attribute referenced by (PyObject *)name */
117   else if (tp->tp_getattro != nullptr) {
118     auto w = py::reinterpret_steal<py::object>(THPUtils_internString(name));
119     if (w.ptr() == nullptr) {
120       return py::object();
121     }
122     res = (*tp->tp_getattro)(obj, w.ptr());
123     if (res == nullptr) {
124       PyErr_Clear();
125     }
126   }
127   return py::reinterpret_steal<py::object>(res);
128 }
129