xref: /aosp_15_r20/external/pytorch/torch/csrc/utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <fmt/core.h>
2 #include <torch/csrc/DynamicTypes.h>
3 #include <torch/csrc/THP.h>
4 #include <torch/csrc/autograd/variable.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/invalid_arguments.h>
7 #include <torch/csrc/utils/python_strings.h>
8 #include <torch/csrc/utils/python_symnode.h>
9 #include <torch/csrc/utils/python_tuples.h>
10 
11 #include <torch/csrc/Export.h>
12 
13 #include <algorithm>
14 #include <cstdarg>
15 #include <cstring>
16 #include <iterator>
17 #include <sstream>
18 #include <string>
19 #include <unordered_map>
20 #include <utility>
21 #include <vector>
22 
THPUtils_getCallable(PyObject * arg,PyObject ** result)23 int THPUtils_getCallable(PyObject* arg, PyObject** result) {
24   if (!PyCallable_Check(arg))
25     return 0;
26   *result = arg;
27   return 1;
28 }
29 
THPUtils_checkIndex(PyObject * obj)30 bool THPUtils_checkIndex(PyObject* obj) {
31   if (PyBool_Check(obj)) {
32     return false;
33   }
34   if (THPUtils_checkLong(obj)) {
35     return true;
36   }
37   // Avoid poking __index__ early as that will immediately cause a guard
38   if (torch::is_symint(py::handle(obj))) {
39     return true;
40   }
41   torch::jit::tracer::NoWarn no_warn_guard;
42   auto index = THPObjectPtr(PyNumber_Index(obj));
43   if (!index) {
44     PyErr_Clear();
45     return false;
46   }
47   return true;
48 }
49 
THPUtils_unpackLongs(PyObject * arg)50 std::vector<int64_t> THPUtils_unpackLongs(PyObject* arg) {
51   bool tuple = PyTuple_Check(arg);
52   bool list = PyList_Check(arg);
53   if (tuple || list) {
54     // NOLINTNEXTLINE(bugprone-branch-clone)
55     const auto nDim = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
56     std::vector<int64_t> sizes(nDim);
57     for (int i = 0; i != nDim; ++i) {
58       PyObject* item =
59           tuple ? PyTuple_GET_ITEM(arg, i) : PyList_GET_ITEM(arg, i);
60       if (!THPUtils_checkLong(item)) {
61         std::ostringstream oss;
62         oss << "expected int at position " << i
63             << ", but got: " << THPUtils_typename(item);
64         throw std::runtime_error(oss.str());
65       }
66       sizes[i] = THPUtils_unpackLong(item);
67     }
68     return sizes;
69   }
70   throw std::runtime_error("Expected tuple or list");
71 }
72 
THPUtils_checkIntTuple(PyObject * arg)73 bool THPUtils_checkIntTuple(PyObject* arg) {
74   if (!PyTuple_Check(arg)) {
75     return false;
76   }
77   for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
78     if (!THPUtils_checkLong(PyTuple_GET_ITEM(arg, i))) {
79       return false;
80     }
81   }
82   return true;
83 }
84 
THPUtils_unpackIntTuple(PyObject * arg)85 std::vector<int> THPUtils_unpackIntTuple(PyObject* arg) {
86   if (!THPUtils_checkIntTuple(arg)) {
87     throw std::runtime_error("Couldn't unpack int tuple");
88   }
89   std::vector<int> values(PyTuple_GET_SIZE(arg));
90   for (Py_ssize_t i = 0; i < PyTuple_GET_SIZE(arg); ++i) {
91     values[i] = (int)THPUtils_unpackLong(PyTuple_GET_ITEM(arg, i));
92   }
93   return values;
94 }
95 
THPUtils_setError(const char * format,...)96 void THPUtils_setError(const char* format, ...) {
97   static const size_t ERROR_BUFFER_SIZE = 1000;
98   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
99   char buffer[ERROR_BUFFER_SIZE];
100   va_list fmt_args;
101 
102   va_start(fmt_args, format);
103   vsnprintf(buffer, ERROR_BUFFER_SIZE, format, fmt_args);
104   va_end(fmt_args);
105   PyErr_SetString(PyExc_RuntimeError, buffer);
106 }
107 
THPUtils_addPyMethodDefs(std::vector<PyMethodDef> & vector,PyMethodDef * methods)108 void THPUtils_addPyMethodDefs(
109     std::vector<PyMethodDef>& vector,
110     PyMethodDef* methods) {
111   if (!vector.empty()) {
112     // remove nullptr terminator
113     vector.pop_back();
114   }
115   while (true) {
116     vector.push_back(*methods);
117     if (!methods->ml_name) {
118       break;
119     }
120     methods++;
121   }
122 }
123 
classOrTypename(PyObject * obj)124 static const char* classOrTypename(PyObject* obj) {
125   if (PyType_Check(obj)) {
126     return ((PyTypeObject*)obj)->tp_name;
127   }
128   return Py_TYPE(obj)->tp_name;
129 }
130 
THPUtils_dispatchStateless(PyObject * tensor,const char * name,PyObject * args,PyObject * kwargs)131 PyObject* THPUtils_dispatchStateless(
132     PyObject* tensor,
133     const char* name,
134     PyObject* args,
135     PyObject* kwargs) {
136   THPObjectPtr methods(
137       PyObject_GetAttrString(tensor, THP_STATELESS_ATTRIBUTE_NAME));
138   if (!methods) {
139     return PyErr_Format(
140         PyExc_TypeError,
141         "Type %s doesn't implement stateless methods",
142         classOrTypename(tensor));
143   }
144   THPObjectPtr method(PyObject_GetAttrString(methods, name));
145   if (!method) {
146     return PyErr_Format(
147         PyExc_TypeError,
148         "Type %s doesn't implement stateless method %s",
149         classOrTypename(tensor),
150         name);
151   }
152   return PyObject_Call(method.get(), args, kwargs);
153 }
154 
THPUtils_invalidArguments(PyObject * given_args,PyObject * given_kwargs,const char * function_name,size_t num_options,...)155 void THPUtils_invalidArguments(
156     PyObject* given_args,
157     PyObject* given_kwargs,
158     const char* function_name,
159     size_t num_options,
160     ...) {
161   std::vector<std::string> option_strings;
162   va_list option_list;
163   va_start(option_list, num_options);
164   std::generate_n(
165       std::back_inserter(option_strings), num_options, [&option_list] {
166         return va_arg(option_list, const char*);
167       });
168   va_end(option_list);
169 
170   PyErr_SetString(
171       PyExc_TypeError,
172       torch::format_invalid_args(
173           given_args, given_kwargs, function_name, option_strings)
174           .c_str());
175 }
176 
177 template <>
free()178 void THPPointer<THPGenerator>::free() {
179   if (ptr)
180     Py_DECREF(ptr);
181 }
182 
183 template class THPPointer<THPGenerator>;
184 
// File-local flag, off by default, exposed only through the setter/getter
// below. NOTE(review): read and written without any synchronization.
static bool backCompatBroadcastWarn{false};

// Enables or disables the back-compat broadcast warning flag.
void setBackCompatBroadcastWarn(bool warn) { backCompatBroadcastWarn = warn; }

// Returns the current value of the back-compat broadcast warning flag.
bool getBackCompatBroadcastWarn() { return backCompatBroadcastWarn; }
194 
// File-local flag, off by default, exposed only through the setter/getter
// below. NOTE(review): read and written without any synchronization.
static bool backCompatKeepdimWarn{false};

// Enables or disables the back-compat keepdim warning flag.
void setBackCompatKeepdimWarn(bool warn) { backCompatKeepdimWarn = warn; }

// Returns the current value of the back-compat keepdim warning flag.
bool getBackCompatKeepdimWarn() { return backCompatKeepdimWarn; }
204 
maybeThrowBackCompatKeepdimWarn(char * func)205 bool maybeThrowBackCompatKeepdimWarn(char* func) {
206   if (getBackCompatKeepdimWarn()) {
207     std::ostringstream ss;
208     ss << "backwards compatibility: call to \"" << func
209        << "\" uses default value for keepdim which has changed default to False.  Consider passing as kwarg.",
210         PyErr_WarnEx(PyExc_UserWarning, ss.str().c_str(), 1);
211   }
212   return true;
213 }
214 
215 template <>
free()216 void THPPointer<THPStorage>::free() {
217   if (ptr)
218     Py_DECREF(ptr);
219 }
220 
storage_fill(const at::Storage & self,uint8_t value)221 void storage_fill(const at::Storage& self, uint8_t value) {
222   auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
223   auto self_t = at::empty({0}, options).set_(self);
224   self_t.fill_(value);
225 }
226 
storage_set(const at::Storage & self,ptrdiff_t idx,uint8_t value)227 void storage_set(const at::Storage& self, ptrdiff_t idx, uint8_t value) {
228   TORCH_CHECK(
229       (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())),
230       "out of bounds");
231   auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
232   auto self_t = at::empty({0}, options).set_(self);
233   self_t[idx].fill_(value);
234 }
235 
storage_get(const at::Storage & self,ptrdiff_t idx)236 uint8_t storage_get(const at::Storage& self, ptrdiff_t idx) {
237   TORCH_CHECK(
238       (idx >= 0) && (idx < static_cast<ptrdiff_t>(self.nbytes())),
239       "out of bounds");
240   auto options = c10::TensorOptions().device(self.device()).dtype(at::kByte);
241   auto self_t = at::empty({0}, options).set_(self);
242   return self_t[idx].item<uint8_t>();
243 }
244 
// Explicitly instantiate the THPStorage pointer in this translation unit,
// using the free() specialization defined above.
245 template class THPPointer<THPStorage>;
246 
247 namespace torch::gdb {
248 /* ~~~ misc debugging utilities ~~~
249  *
250  * torch::gdb::* functions are NOT meant to be called by general pytorch code,
251  * but only from within a gdb session. As such, utils.h does not contain any
252  * declaration for those.
253  */
254 
255 // This is a helper needed by the torch-tensor-repr gdb command.
256 // Return an human-readable representation of the given Tensor. The resulting
257 // string is stored into a malloc()ed buffer. The caller is responsible to
258 // free() it. We use malloc() instead of new[] because it's much easier to
259 // call free than delete[] from withing gdb.
260 // Currently the code for computing the repr of a tensor is written in Python,
261 // so we need to wrap the Tensor into a Python object first.
tensor_repr(at::Tensor tensor)262 char* tensor_repr(at::Tensor tensor) {
263   PyGILState_STATE gil = PyGILState_Ensure();
264   PyObject* pytensor = nullptr;
265   PyObject* repr = nullptr;
266   Py_ssize_t bufsize = 0;
267   const char* buf = nullptr;
268   char* result = nullptr;
269 
270   // NB: It's important not to move the tensor into THPVariable_Wrap,
271   // because this function is only called from our gdb macros, and
272   // we want to avoid accidentally moving out the tensor.  In principle,
273   // the Tensor signature above should induce a copy, but we've
274   // observed that sometimes gdb passes the outer Tensor address exactly as is
275   // into this function.
276   // See https://github.com/pytorch/pytorch/issues/134762
277   pytensor = THPVariable_Wrap(tensor);
278   if (!pytensor)
279     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
280     goto error;
281   repr = PyObject_Repr(pytensor);
282   if (!repr)
283     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
284     goto error;
285   buf = PyUnicode_AsUTF8AndSize(repr, &bufsize);
286   if (!buf)
287     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
288     goto error;
289   // account for the trailing \0
290   // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
291   result = static_cast<char*>(malloc(bufsize + 1));
292   if (!result) {
293     fmt::print(stderr, "cannot allocate memory for the result\n");
294     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
295     goto error;
296   }
297   std::strncpy(result, buf, bufsize);
298   result[bufsize] = '\0';
299   Py_XDECREF(pytensor);
300   Py_XDECREF(repr);
301   PyGILState_Release(gil);
302   return result;
303 
304 error:
305   fprintf(stderr, "torch::gdb::tensor_repr: unexpected error\n");
306   if (PyErr_Occurred())
307     PyErr_Print();
308   Py_XDECREF(pytensor);
309   Py_XDECREF(repr);
310   // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
311   free(result);
312   PyGILState_Release(gil);
313   return nullptr;
314 }
315 
int_array_ref_string(at::IntArrayRef sizes)316 std::string int_array_ref_string(at::IntArrayRef sizes) {
317   std::stringstream ss;
318   ss << sizes;
319   return ss.str();
320 }
321 
dispatch_keyset_string(c10::DispatchKeySet keyset)322 std::string dispatch_keyset_string(c10::DispatchKeySet keyset) {
323   std::stringstream ss;
324   ss << keyset;
325   return ss.str();
326 }
327 
328 } // namespace torch::gdb
329 
330 namespace pybind11::detail {
331 
load(handle src,bool)332 bool type_caster<at::Tensor>::load(handle src, bool) {
333   PyObject* obj = src.ptr();
334   if (THPVariable_Check(obj)) {
335     value = THPVariable_Unpack(obj);
336     return true;
337   }
338   return false;
339 }
340 
cast(const at::Tensor & src,return_value_policy,handle)341 handle type_caster<at::Tensor>::cast(
342     const at::Tensor& src,
343     return_value_policy /* policy */,
344     handle /* parent */) {
345   return handle(THPVariable_Wrap(src));
346 }
347 
load(handle src,bool)348 bool type_caster<at::IntArrayRef>::load(handle src, bool) {
349   PyObject* source = src.ptr();
350   auto tuple = PyTuple_Check(source);
351   if (tuple || PyList_Check(source)) {
352     // NOLINTNEXTLINE(bugprone-branch-clone)
353     const auto size =
354         tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
355     v_value.resize(size);
356     for (const auto idx : c10::irange(size)) {
357       PyObject* obj =
358           tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
359       if (THPVariable_Check(obj)) {
360         v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
361       } else if (PyLong_Check(obj)) {
362         // use THPUtils_unpackLong after it is safe to include
363         // python_numbers.h
364         v_value[idx] = THPUtils_unpackLong(obj);
365       } else {
366         return false;
367       }
368     }
369     value = v_value;
370     return true;
371   }
372   return false;
373 }
cast(at::IntArrayRef src,return_value_policy,handle)374 handle type_caster<at::IntArrayRef>::cast(
375     at::IntArrayRef src,
376     return_value_policy /* policy */,
377     handle /* parent */) {
378   return handle(THPUtils_packInt64Array(src.size(), src.data()));
379 }
380 
load(handle src,bool)381 bool type_caster<at::SymIntArrayRef>::load(handle src, bool) {
382   PyObject* source = src.ptr();
383 
384   auto tuple = PyTuple_Check(source);
385   if (tuple || PyList_Check(source)) {
386     // NOLINTNEXTLINE(bugprone-branch-clone)
387     const auto size =
388         tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source);
389     v_value.resize(size);
390     for (const auto idx : c10::irange(size)) {
391       PyObject* obj =
392           tuple ? PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx);
393 
394       if (THPVariable_Check(obj)) {
395         // TODO: this is for consistency with IntArrayRef but arguably
396         // we shouldn't really allow this on pybind11 casters
397         v_value[idx] = THPVariable_Unpack(obj).item<int64_t>();
398       } else if (torch::is_symint(py::handle(obj))) {
399         v_value[idx] = py::handle(obj).cast<c10::SymInt>();
400       } else if (PyLong_Check(obj)) {
401         v_value[idx] = c10::SymInt(THPUtils_unpackIndex(obj));
402       } else {
403         return false;
404       }
405     }
406     value = v_value;
407     return true;
408   }
409   return false;
410 }
cast(at::SymIntArrayRef src,return_value_policy,handle)411 handle type_caster<at::SymIntArrayRef>::cast(
412     at::SymIntArrayRef src,
413     return_value_policy /* policy */,
414     handle /* parent */) {
415   py::list t(src.size());
416   for (const auto i : c10::irange(src.size())) {
417     t[i] = py::cast(src[i]);
418   }
419   return t.release();
420 }
421 
load(handle src,bool)422 bool type_caster<at::ArrayRef<c10::SymNode>>::load(handle src, bool) {
423   TORCH_INTERNAL_ASSERT(0, "NYI");
424 }
cast(at::ArrayRef<c10::SymNode> src,return_value_policy,handle)425 handle type_caster<at::ArrayRef<c10::SymNode>>::cast(
426     at::ArrayRef<c10::SymNode> src,
427     return_value_policy /* policy */,
428     handle /* parent */) {
429   py::list t(src.size());
430   for (const auto i : c10::irange(src.size())) {
431     // TODO: this is terrible but I don't know how to override when
432     // the SymNode is also explicitly cast by py::cast
433     auto* py_node = dynamic_cast<torch::impl::PythonSymNodeImpl*>(src[i].get());
434     if (py_node) {
435       // Return the Python directly (unwrap)
436       t[i] = py_node->getPyObj();
437     } else {
438       t[i] = py::cast(src[i]);
439     }
440   }
441   return t.release();
442 }
443 
444 } // namespace pybind11::detail
445