xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_arg_flatten.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/python/python_arg_flatten.h>
3 #include <torch/csrc/utils/python_strings.h>
4 #include <torch/csrc/utils/six.h>
5 
6 #include <torch/csrc/autograd/grad_mode.h>
7 
8 namespace torch::jit::python {
9 
10 using namespace torch::autograd;
11 using namespace at;
12 
13 // Alphabet used to describe structure of inputs/outputs (D for desc)
namespace D {
// One character per node kind; a flattened structure is encoded as a string
// over this alphabet (e.g. "([vv]v)" = tuple(list(var, var), var)).
static constexpr char DictOpen = '<'; // dict begins; items follow as 2-tuples
static constexpr char DictClose = '>'; // dict ends
static constexpr char ListOpen = '['; // list begins
static constexpr char ListClose = ']'; // list ends
static constexpr char TupleOpen = '('; // tuple begins
static constexpr char TupleClose = ')'; // tuple ends
static constexpr char Variable = 'v'; // a tensor/Variable leaf
static constexpr char Bool = 'b'; // a Python bool, wrapped as a Bool tensor
static constexpr char Long = 'l'; // a Python int, wrapped as a Long tensor
static constexpr char Double = 'd'; // a Python float, wrapped as a Double tensor
static constexpr char String = 's'; // a Python str, stored in desc.strings
static constexpr char NoneType = 'n'; // Python None
} // namespace D
28 
29 namespace {
30 
PyNone_Check(PyObject * o)31 inline bool PyNone_Check(PyObject* o) {
32   return o == Py_None;
33 }
34 
35 template <typename T>
cast_handle_sequence(std::vector<py::handle> objs)36 py::object cast_handle_sequence(std::vector<py::handle> objs) {
37   auto num_objs = objs.size();
38   T sequence{num_objs};
39   for (const auto i : c10::irange(num_objs)) {
40     sequence[i] = py::reinterpret_borrow<py::object>(objs[i]);
41   }
42   return sequence;
43 }
44 
flatten_rec(PyObject * obj,ParsedArgs & args)45 void flatten_rec(PyObject* obj, ParsedArgs& args) {
46   auto& structure = args.desc.structure;
47   if (six::isTuple(obj)) {
48     structure.push_back(D::TupleOpen);
49     for (auto item : py::reinterpret_borrow<py::tuple>(obj))
50       flatten_rec(item.ptr(), args);
51     structure.push_back(D::TupleClose);
52   } else if (PyList_Check(obj)) {
53     structure.push_back(D::ListOpen);
54     for (auto item : py::reinterpret_borrow<py::list>(obj))
55       flatten_rec(item.ptr(), args);
56     structure.push_back(D::ListClose);
57   } else if (PyDict_Check(obj)) {
58     auto* dict_items = PyDict_Items(obj);
59     structure.push_back(D::DictOpen);
60     for (auto item : py::reinterpret_borrow<py::list>(dict_items)) {
61       flatten_rec(item.ptr(), args);
62     }
63     structure.push_back(D::DictClose);
64     Py_DECREF(dict_items);
65   } else if (THPUtils_checkString(obj)) {
66     string str = THPUtils_unpackString(obj);
67     args.desc.strings.emplace_back(str);
68     args.desc.structure.push_back(D::String);
69   } else if (THPVariable_Check(obj)) {
70     auto& var = THPVariable_Unpack(obj);
71     args.vars.push_back(var);
72     args.desc.metadata.emplace_back(var);
73     args.desc.structure.push_back(D::Variable);
74   } else if (PyNone_Check(obj)) {
75     args.desc.structure.push_back(D::NoneType);
76   } else if (PyBool_Check(obj)) { // Wrap bools in Bool tensors
77     at::Tensor var = scalar_to_tensor(at::Scalar(THPUtils_unpackBool(obj)));
78     args.vars.push_back(var);
79     args.desc.metadata.emplace_back(var);
80     args.desc.structure.push_back(D::Bool);
81   } else if (PyLong_Check(obj)) { // Wrap longs in Long tensors
82     at::Tensor var = scalar_to_tensor(
83         at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(obj))));
84     args.vars.push_back(var);
85     args.desc.metadata.emplace_back(var);
86     args.desc.structure.push_back(D::Long);
87   } else if (PyFloat_Check(obj)) { // Wrap floats in Double tensors
88     at::Tensor var = scalar_to_tensor(THPUtils_unpackDouble(obj));
89     args.vars.push_back(var);
90     args.desc.metadata.emplace_back(var);
91     args.desc.structure.push_back(D::Double);
92   } else {
93     std::string msg =
94         "Only tuples, lists and Variables are supported as JIT inputs/outputs. "
95         "Dictionaries and strings are also accepted, but their usage is not "
96         "recommended. Here, received an input of unsupported type: ";
97     msg += THPUtils_typename(obj);
98     throw std::runtime_error(msg);
99   }
100 }
101 
102 } // anonymous namespace
103 
flatten(py::handle obj)104 ParsedArgs flatten(py::handle obj) {
105   ParsedArgs args;
106   args.desc.grad_enabled = autograd::GradMode::is_enabled();
107   flatten_rec(obj.ptr(), args);
108   return args;
109 }
110 
111 namespace {
112 
113 template <typename T>
cast_sequence(std::vector<py::object> objs)114 py::object cast_sequence(std::vector<py::object> objs) {
115   auto num_objs = objs.size();
116   T sequence{num_objs};
117   for (const auto i : c10::irange(num_objs)) {
118     sequence[i] = std::move(objs[i]);
119   }
120   return std::move(sequence);
121 }
122 
cast_dict(std::vector<py::object> objs)123 py::object cast_dict(std::vector<py::object> objs) {
124   auto num_objs = objs.size();
125   py::dict sequence = {};
126   for (const auto i : c10::irange(num_objs)) {
127     py::tuple obj = py::reinterpret_borrow<py::tuple>(objs[i]);
128     sequence[obj[0]] = obj[1];
129   }
130   return std::move(sequence);
131 }
132 
unflatten_rec(ArrayRef<Variable>::iterator & var_it,ArrayRef<Variable>::iterator & var_it_end,std::string::const_iterator & desc_it,std::vector<string>::const_iterator & str_it,std::vector<string>::const_iterator & str_it_end)133 py::object unflatten_rec(
134     ArrayRef<Variable>::iterator& var_it,
135     ArrayRef<Variable>::iterator& var_it_end,
136     std::string::const_iterator& desc_it,
137     std::vector<string>::const_iterator& str_it,
138     std::vector<string>::const_iterator& str_it_end) {
139   char type = *desc_it++;
140   if (type == D::TupleOpen) {
141     std::vector<py::object> objs;
142     while (*desc_it != D::TupleClose)
143       objs.push_back(
144           unflatten_rec(var_it, var_it_end, desc_it, str_it, str_it_end));
145     ++desc_it;
146     return cast_sequence<py::tuple>(objs);
147   } else if (type == D::ListOpen) {
148     std::vector<py::object> objs;
149     while (*desc_it != D::ListClose)
150       objs.push_back(
151           unflatten_rec(var_it, var_it_end, desc_it, str_it, str_it_end));
152     ++desc_it;
153     return cast_sequence<py::list>(objs);
154   } else if (type == D::DictOpen) {
155     std::vector<py::object> objs;
156     while (*desc_it != D::DictClose) {
157       objs.push_back(
158           unflatten_rec(var_it, var_it_end, desc_it, str_it, str_it_end));
159     }
160     ++desc_it;
161     return cast_dict(objs);
162   } else if (type == D::String) {
163     if (str_it == str_it_end)
164       throw std::runtime_error("Not enough Variables given to unflatten");
165     auto str = *str_it++;
166     return py::reinterpret_borrow<py::object>(THPUtils_packString(str));
167   } else if (type == D::NoneType) {
168     return py::reinterpret_borrow<py::object>(py::none());
169   } else {
170     // if (type == D::Long || type == D::Double || type == D::Bool ||
171     // D::Variable) unwrap variables (D::Variable), or unwrap primitive types
172     // (Long, Double, Bool) as variables for tracer.
173     if (var_it == var_it_end)
174       throw std::runtime_error("Not enough Variables given to unflatten");
175     auto var = *var_it++;
176     return py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
177   }
178 }
179 
180 } // anonymous namespace
181 
unflatten(ArrayRef<Variable> vars,const IODescriptor & desc)182 PyObject* unflatten(ArrayRef<Variable> vars, const IODescriptor& desc) {
183   // NB: We don't do correctness checking on descriptor.
184   // It has to be a correct bytes object produced by unflatten.
185   auto vars_it = vars.begin();
186   auto vars_it_end = vars.end();
187   auto desc_it = desc.structure.begin();
188   std::vector<std::string>::const_iterator str_it = desc.strings.begin();
189   std::vector<std::string>::const_iterator str_end = desc.strings.end();
190   auto output = unflatten_rec(vars_it, vars_it_end, desc_it, str_it, str_end);
191   if (vars_it != vars_it_end)
192     throw std::runtime_error("Too many Variables given to unflatten");
193   return output.release().ptr();
194 }
195 
196 } // namespace torch::jit::python
197