xref: /aosp_15_r20/external/pytorch/torch/csrc/copy_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Types.h>
4 #include <torch/csrc/python_headers.h>
5 #include <torch/csrc/utils.h>
6 #include <functional>
7 #include <vector>
8 
9 typedef std::function<void(PyObject*, PyObject*, bool)> THPCopyFunction;
10 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
11 struct THPCopyInfo {
12   PyTypeObject* srcType; // Python type of src tensor/storage
13   THPCopyFunction copy; // copy function
14   bool non_blocking; // true if copy implements an 'non_blocking' copy
15   bool broadcast; // true if the copy implements a broadcast copy
16 };
17 typedef std::vector<THPCopyInfo> THPCopyList;
18 
tryTHPCopy(const THPCopyList & v,PyObject * dst,PyObject * src,bool non_blocking,bool broadcast)19 inline bool tryTHPCopy(
20     const THPCopyList& v,
21     PyObject* dst,
22     PyObject* src,
23     bool non_blocking,
24     bool broadcast) {
25   for (auto& i : v) {
26     if (i.non_blocking == non_blocking &&
27         PyType_IsSubtype(Py_TYPE(src), i.srcType)) {
28       (i.copy)(dst, src, broadcast);
29       return true;
30     }
31   }
32   return false;
33 }
34 
THPCopy(const THPCopyList & v,PyObject * dst,PyObject * src,bool non_blocking,bool broadcast)35 inline bool THPCopy(
36     const THPCopyList& v,
37     PyObject* dst,
38     PyObject* src,
39     bool non_blocking,
40     bool broadcast) {
41   // NOLINTNEXTLINE(bugprone-branch-clone)
42   if (tryTHPCopy(v, dst, src, non_blocking, broadcast)) {
43     return true;
44   } else if (non_blocking && tryTHPCopy(v, dst, src, false, broadcast)) {
45     return true;
46   }
47   THPUtils_setError(
48       "copy from %s to %s isn't implemented",
49       THPUtils_typename(src),
50       THPUtils_typename(dst));
51   return false;
52 }
53