xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/utils/wrap_outputs.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
// Overloads of wrap() that convert tensor operation outputs and related C++ values into PyObject*.
4 
5 #include <ATen/ScalarOps.h>
6 #include <ATen/core/Tensor.h>
7 #include <c10/util/irange.h>
8 #include <torch/csrc/python_headers.h>
9 #include <initializer_list>
10 #include <tuple>
11 
12 #include <torch/csrc/Dtype.h>
13 #include <torch/csrc/DynamicTypes.h>
14 #include <torch/csrc/Layout.h>
15 #include <torch/csrc/QScheme.h>
16 #include <torch/csrc/autograd/python_variable.h>
17 #include <torch/csrc/autograd/variable.h>
18 #include <torch/csrc/utils/python_numbers.h>
19 #include <torch/csrc/utils/tensor_qschemes.h>
20 
21 namespace torch::autograd::utils {
22 
wrap(bool value)23 inline PyObject* wrap(bool value) {
24   if (value) {
25     Py_RETURN_TRUE;
26   } else {
27     Py_RETURN_FALSE;
28   }
29 }
30 
wrap(c10::DeviceIndex value)31 inline PyObject* wrap(c10::DeviceIndex value) {
32   return THPUtils_packDeviceIndex(value);
33 }
34 
wrap(int64_t value)35 inline PyObject* wrap(int64_t value) {
36   return THPUtils_packInt64(value);
37 }
38 
wrap(double value)39 inline PyObject* wrap(double value) {
40   return PyFloat_FromDouble(value);
41 }
42 
wrap(c10::complex<double> value)43 inline PyObject* wrap(c10::complex<double> value) {
44   // I could probably also use FromComplex with a reinterpret cast,
45   // but... eh.
46   return PyComplex_FromDoubles(value.real(), value.imag());
47 }
48 
wrap(void * value)49 inline PyObject* wrap(void* value) {
50   return THPUtils_packInt64(reinterpret_cast<intptr_t>(value));
51 }
52 
wrap(THPDtype * dtype)53 inline PyObject* wrap(THPDtype* dtype) {
54   return Py_NewRef(dtype);
55 }
56 
wrap(at::ScalarType scalarType)57 inline PyObject* wrap(at::ScalarType scalarType) {
58   return Py_NewRef(getTHPDtype(scalarType));
59 }
60 
wrap(THPLayout * layout)61 inline PyObject* wrap(THPLayout* layout) {
62   return Py_NewRef(layout);
63 }
64 
wrap(at::Layout layout)65 inline PyObject* wrap(at::Layout layout) {
66   return Py_NewRef(getTHPLayout(layout));
67 }
68 
wrap(at::Tensor tensor)69 inline PyObject* wrap(at::Tensor tensor) {
70   return THPVariable_Wrap(Variable(std::move(tensor)));
71 }
72 
wrap(const at::Scalar & scalar)73 inline PyObject* wrap(const at::Scalar& scalar) {
74   return wrap(scalar_to_tensor(scalar));
75 }
76 
wrap(at::QScheme qscheme)77 inline PyObject* wrap(at::QScheme qscheme) {
78   auto* thp_qscheme = torch::utils::getTHPQScheme(qscheme);
79   Py_INCREF(thp_qscheme);
80   return thp_qscheme;
81 }
82 
wrap(at::TensorList tl)83 inline PyObject* wrap(at::TensorList tl) {
84   auto r = THPObjectPtr{PyTuple_New(tl.size())};
85   if (!r)
86     throw python_error();
87   for (const auto i : c10::irange(tl.size())) {
88     PyTuple_SET_ITEM(r.get(), i, wrap(tl[i]));
89   }
90   return r.release();
91 }
92 
wrap(at::IntArrayRef list)93 inline PyObject* wrap(at::IntArrayRef list) {
94   auto r = THPObjectPtr{PyTuple_New(list.size())};
95   if (!r)
96     throw python_error();
97   for (const auto i : c10::irange(list.size())) {
98     PyTuple_SET_ITEM(r.get(), i, wrap(list[i]));
99   }
100   return r.release();
101 }
102 
wrap(at::Stream stream)103 inline PyObject* wrap(at::Stream stream) {
104   return THPStream_Wrap(stream);
105 }
106 
namespace detail {
// Invokes `f(element, index)` for each tuple element selected by the index
// pack. A comma fold expression evaluates its operands strictly left to
// right, so the calls happen in index order. Each element is passed as an
// lvalue reference into `t`, so `f` may modify or move from it.
template <typename F, typename Tuple, size_t... Is>
void apply_with_idx_impl(
    const F& f,
    Tuple& t,
    std::index_sequence<Is...> /*indices*/) {
  ((void)f(std::get<Is>(t), Is), ...);
}

// For tuple(a, b, c), calls f(a, 0), f(b, 1), f(c, 2), in that order.
// An empty tuple results in no calls.
template <typename F, typename... Ts>
void apply_with_idx(const F& f, std::tuple<Ts...>& t) {
  using Indices = std::make_index_sequence<sizeof...(Ts)>;
  apply_with_idx_impl(f, t, Indices{});
}
} // namespace detail
122 
123 template <typename... Ts>
wrap(std::tuple<Ts...> values)124 PyObject* wrap(std::tuple<Ts...> values) {
125   auto r = THPObjectPtr{PyTuple_New(sizeof...(Ts))};
126   if (!r)
127     throw python_error();
128   detail::apply_with_idx(
129       [&](auto& value, size_t idx) {
130         PyTuple_SET_ITEM(r.get(), idx, wrap(std::move(value)));
131       },
132       values);
133   return r.release();
134 }
135 
136 template <typename... Ts>
wrap(PyTypeObject * type,std::tuple<Ts...> values)137 PyObject* wrap(PyTypeObject* type, std::tuple<Ts...> values) {
138   auto r = THPObjectPtr{PyStructSequence_New(type)};
139   if (!r)
140     throw python_error();
141   detail::apply_with_idx(
142       [&](auto& value, size_t idx) {
143         PyStructSequence_SET_ITEM(r.get(), idx, wrap(std::move(value)));
144       },
145       values);
146   return r.release();
147 }
148 
149 } // namespace torch::autograd::utils
150