1 #pragma once
2
3 // Wrap tensor operation outputs as PyObject*
4
5 #include <ATen/ScalarOps.h>
6 #include <ATen/core/Tensor.h>
7 #include <c10/util/irange.h>
8 #include <torch/csrc/python_headers.h>
9 #include <initializer_list>
10 #include <tuple>
11
12 #include <torch/csrc/Dtype.h>
13 #include <torch/csrc/DynamicTypes.h>
14 #include <torch/csrc/Layout.h>
15 #include <torch/csrc/QScheme.h>
16 #include <torch/csrc/autograd/python_variable.h>
17 #include <torch/csrc/autograd/variable.h>
18 #include <torch/csrc/utils/python_numbers.h>
19 #include <torch/csrc/utils/tensor_qschemes.h>
20
21 namespace torch::autograd::utils {
22
wrap(bool value)23 inline PyObject* wrap(bool value) {
24 if (value) {
25 Py_RETURN_TRUE;
26 } else {
27 Py_RETURN_FALSE;
28 }
29 }
30
wrap(c10::DeviceIndex value)31 inline PyObject* wrap(c10::DeviceIndex value) {
32 return THPUtils_packDeviceIndex(value);
33 }
34
wrap(int64_t value)35 inline PyObject* wrap(int64_t value) {
36 return THPUtils_packInt64(value);
37 }
38
wrap(double value)39 inline PyObject* wrap(double value) {
40 return PyFloat_FromDouble(value);
41 }
42
wrap(c10::complex<double> value)43 inline PyObject* wrap(c10::complex<double> value) {
44 // I could probably also use FromComplex with a reinterpret cast,
45 // but... eh.
46 return PyComplex_FromDoubles(value.real(), value.imag());
47 }
48
wrap(void * value)49 inline PyObject* wrap(void* value) {
50 return THPUtils_packInt64(reinterpret_cast<intptr_t>(value));
51 }
52
wrap(THPDtype * dtype)53 inline PyObject* wrap(THPDtype* dtype) {
54 return Py_NewRef(dtype);
55 }
56
wrap(at::ScalarType scalarType)57 inline PyObject* wrap(at::ScalarType scalarType) {
58 return Py_NewRef(getTHPDtype(scalarType));
59 }
60
wrap(THPLayout * layout)61 inline PyObject* wrap(THPLayout* layout) {
62 return Py_NewRef(layout);
63 }
64
wrap(at::Layout layout)65 inline PyObject* wrap(at::Layout layout) {
66 return Py_NewRef(getTHPLayout(layout));
67 }
68
wrap(at::Tensor tensor)69 inline PyObject* wrap(at::Tensor tensor) {
70 return THPVariable_Wrap(Variable(std::move(tensor)));
71 }
72
wrap(const at::Scalar & scalar)73 inline PyObject* wrap(const at::Scalar& scalar) {
74 return wrap(scalar_to_tensor(scalar));
75 }
76
wrap(at::QScheme qscheme)77 inline PyObject* wrap(at::QScheme qscheme) {
78 auto* thp_qscheme = torch::utils::getTHPQScheme(qscheme);
79 Py_INCREF(thp_qscheme);
80 return thp_qscheme;
81 }
82
wrap(at::TensorList tl)83 inline PyObject* wrap(at::TensorList tl) {
84 auto r = THPObjectPtr{PyTuple_New(tl.size())};
85 if (!r)
86 throw python_error();
87 for (const auto i : c10::irange(tl.size())) {
88 PyTuple_SET_ITEM(r.get(), i, wrap(tl[i]));
89 }
90 return r.release();
91 }
92
wrap(at::IntArrayRef list)93 inline PyObject* wrap(at::IntArrayRef list) {
94 auto r = THPObjectPtr{PyTuple_New(list.size())};
95 if (!r)
96 throw python_error();
97 for (const auto i : c10::irange(list.size())) {
98 PyTuple_SET_ITEM(r.get(), i, wrap(list[i]));
99 }
100 return r.release();
101 }
102
wrap(at::Stream stream)103 inline PyObject* wrap(at::Stream stream) {
104 return THPStream_Wrap(stream);
105 }
106
namespace detail {
// Calls f(std::get<I>(t), I) for every index I in the pack, in increasing
// order. The built-in comma operator in the fold sequences the calls left to
// right. Each result is cast to void so that a user-overloaded operator, can
// never be selected.
template <typename F, typename Tuple, size_t... Is>
void apply_with_idx_impl(
    const F& f,
    Tuple& t,
    std::index_sequence<Is...> /*indices*/) {
  (static_cast<void>(f(std::get<Is>(t), Is)), ...);
}

// For tuple(a, b, c), calls f(a, 0), f(b, 1), f(c, 2).
// Elements are passed as lvalue references, so f may modify or move from them.
template <typename F, typename... Ts>
void apply_with_idx(const F& f, std::tuple<Ts...>& t) {
  using Indices = std::index_sequence_for<Ts...>;
  apply_with_idx_impl(f, t, Indices{});
}
} // namespace detail
122
123 template <typename... Ts>
wrap(std::tuple<Ts...> values)124 PyObject* wrap(std::tuple<Ts...> values) {
125 auto r = THPObjectPtr{PyTuple_New(sizeof...(Ts))};
126 if (!r)
127 throw python_error();
128 detail::apply_with_idx(
129 [&](auto& value, size_t idx) {
130 PyTuple_SET_ITEM(r.get(), idx, wrap(std::move(value)));
131 },
132 values);
133 return r.release();
134 }
135
136 template <typename... Ts>
wrap(PyTypeObject * type,std::tuple<Ts...> values)137 PyObject* wrap(PyTypeObject* type, std::tuple<Ts...> values) {
138 auto r = THPObjectPtr{PyStructSequence_New(type)};
139 if (!r)
140 throw python_error();
141 detail::apply_with_idx(
142 [&](auto& value, size_t idx) {
143 PyStructSequence_SET_ITEM(r.get(), idx, wrap(std::move(value)));
144 },
145 values);
146 return r.release();
147 }
148
149 } // namespace torch::autograd::utils
150