1 #pragma once
2
3 #include <torch/csrc/Types.h>
4 #include <torch/csrc/python_headers.h>
5 #include <torch/csrc/utils.h>
6 #include <functional>
7 #include <vector>
8
9 typedef std::function<void(PyObject*, PyObject*, bool)> THPCopyFunction;
10 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
11 struct THPCopyInfo {
12 PyTypeObject* srcType; // Python type of src tensor/storage
13 THPCopyFunction copy; // copy function
14 bool non_blocking; // true if copy implements an 'non_blocking' copy
15 bool broadcast; // true if the copy implements a broadcast copy
16 };
17 typedef std::vector<THPCopyInfo> THPCopyList;
18
tryTHPCopy(const THPCopyList & v,PyObject * dst,PyObject * src,bool non_blocking,bool broadcast)19 inline bool tryTHPCopy(
20 const THPCopyList& v,
21 PyObject* dst,
22 PyObject* src,
23 bool non_blocking,
24 bool broadcast) {
25 for (auto& i : v) {
26 if (i.non_blocking == non_blocking &&
27 PyType_IsSubtype(Py_TYPE(src), i.srcType)) {
28 (i.copy)(dst, src, broadcast);
29 return true;
30 }
31 }
32 return false;
33 }
34
THPCopy(const THPCopyList & v,PyObject * dst,PyObject * src,bool non_blocking,bool broadcast)35 inline bool THPCopy(
36 const THPCopyList& v,
37 PyObject* dst,
38 PyObject* src,
39 bool non_blocking,
40 bool broadcast) {
41 // NOLINTNEXTLINE(bugprone-branch-clone)
42 if (tryTHPCopy(v, dst, src, non_blocking, broadcast)) {
43 return true;
44 } else if (non_blocking && tryTHPCopy(v, dst, src, false, broadcast)) {
45 return true;
46 }
47 THPUtils_setError(
48 "copy from %s to %s isn't implemented",
49 THPUtils_typename(src),
50 THPUtils_typename(dst));
51 return false;
52 }
53