1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <pybind11/pybind11.h>
10 #include <pybind11/stl.h>
11 #include <memory>
12 #include <stack>
13
14 #include "executorch/extension/pytree/pytree.h"
15
16 namespace py = pybind11;
17
18 namespace executorch {
19 namespace extension {
20 namespace pytree {
21
22 namespace {
23
24 struct PyAux {
25 py::object custom_type_context;
26 };
27 using PyTreeSpec = TreeSpec<PyAux>;
28
29 class PyTypeRegistry {
30 public:
31 struct PyTypeReg {
PyTypeRegexecutorch::extension::pytree::__anon59fa55ac0111::PyTypeRegistry::PyTypeReg32 explicit PyTypeReg(Kind k) : kind(k) {}
33
34 Kind kind;
35
36 // for custom types
37 py::object type;
38 // function type: object -> (children, spec_data)
39 py::function flatten;
40 // function type: (children, spec_data) -> object
41 py::function unflatten;
42 };
43
get_by_str(const std::string & pytype)44 static const PyTypeReg* get_by_str(const std::string& pytype) {
45 auto* registry = instance();
46 auto it = registry->regs_.find(pytype);
47 return it == registry->regs_.end() ? nullptr : it->second.get();
48 }
49
get_by_type(py::handle pytype)50 static const PyTypeReg* get_by_type(py::handle pytype) {
51 return get_by_str(py::str(pytype));
52 }
53
register_custom_type(py::object type,py::function flatten,py::function unflatten)54 static void register_custom_type(
55 py::object type,
56 py::function flatten,
57 py::function unflatten) {
58 auto* registry = instance();
59 auto reg = std::make_unique<PyTypeReg>(Kind::Custom);
60 reg->type = type;
61 reg->flatten = std::move(flatten);
62 reg->unflatten = std::move(unflatten);
63 std::string pytype_str = py::str(type);
64 auto it = registry->regs_.emplace(pytype_str, std::move(reg));
65 if (!it.second) {
66 assert(false);
67 }
68 }
69
70 private:
instance()71 static PyTypeRegistry* instance() {
72 static auto* registry_instance = []() -> PyTypeRegistry* {
73 auto* registry = new PyTypeRegistry;
74
75 auto add_pytype_reg = [&](const std::string& pytype, Kind kind) {
76 registry->regs_.emplace(pytype, std::make_unique<PyTypeReg>(kind));
77 };
78
79 add_pytype_reg("<class 'tuple'>", Kind::Tuple);
80 add_pytype_reg("<class 'list'>", Kind::List);
81 add_pytype_reg("<class 'dict'>", Kind::Dict);
82
83 return registry;
84 }();
85
86 return registry_instance;
87 }
88 std::unordered_map<std::string, std::unique_ptr<PyTypeReg>> regs_;
89 };
90
91 class PyTree {
92 PyTreeSpec spec_;
93
flatten_internal(py::handle x,std::vector<py::object> & leaves,PyTreeSpec & s)94 static void flatten_internal(
95 py::handle x,
96 std::vector<py::object>& leaves,
97 PyTreeSpec& s) {
98 const auto* reg = PyTypeRegistry::get_by_type(x.get_type());
99 const auto kind = [®, &x]() {
100 if (reg) {
101 return reg->kind;
102 }
103 if (py::isinstance<py::tuple>(x) && py::hasattr(x, "_fields")) {
104 return Kind::NamedTuple;
105 }
106 return Kind::Leaf;
107 }();
108 switch (kind) {
109 case Kind::List: {
110 const size_t n = PyList_GET_SIZE(x.ptr());
111 s = PyTreeSpec(Kind::List, n);
112 for (size_t i = 0; i < n; ++i) {
113 flatten_internal(PyList_GET_ITEM(x.ptr(), i), leaves, s[i]);
114 }
115 break;
116 }
117 case Kind::Tuple: {
118 const size_t n = PyTuple_GET_SIZE(x.ptr());
119 s = PyTreeSpec(Kind::Tuple, n);
120 for (size_t i = 0; i < n; ++i) {
121 flatten_internal(PyTuple_GET_ITEM(x.ptr(), i), leaves, s[i]);
122 }
123 break;
124 }
125 case Kind::NamedTuple: {
126 py::tuple tuple = py::reinterpret_borrow<py::tuple>(x);
127 const size_t n = tuple.size();
128 s = PyTreeSpec(Kind::NamedTuple, n);
129 size_t i = 0;
130 for (py::handle entry : tuple) {
131 flatten_internal(entry, leaves, s[i++]);
132 }
133 break;
134 }
135 case Kind::Dict: {
136 py::dict dict = py::reinterpret_borrow<py::dict>(x);
137 py::list keys =
138 py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
139 const auto n = PyList_GET_SIZE(keys.ptr());
140 s = PyTreeSpec(Kind::Dict, n);
141 size_t i = 0;
142 for (py::handle key : keys) {
143 if (py::isinstance<py::str>(key)) {
144 s.key(i) = py::cast<std::string>(key);
145 } else if (py::isinstance<py::int_>(key)) {
146 s.key(i) = py::cast<int32_t>(key);
147 } else {
148 pytree_assert(false);
149 }
150
151 flatten_internal(dict[key], leaves, s[i]);
152 i++;
153 }
154 break;
155 }
156 case Kind::Custom: {
157 py::tuple out = py::cast<py::tuple>(reg->flatten(x));
158 if (out.size() != 2) {
159 assert(false);
160 }
161 py::list children = py::cast<py::list>(out[0]);
162 const size_t n = children.size();
163 s = PyTreeSpec(Kind::Custom, n);
164 s.handle->custom_type = py::str(x.get_type());
165 s.handle->custom_type_context = out[1];
166 size_t i = 0;
167 for (py::handle pychild : children) {
168 flatten_internal(pychild, leaves, s[i++]);
169 }
170 break;
171 }
172 case Kind::Leaf: {
173 s = PyTreeSpec(Kind::Leaf);
174 leaves.push_back(py::reinterpret_borrow<py::object>(x));
175 break;
176 }
177 case Kind::None:
178 pytree_assert(false);
179 }
180 }
181
182 template <typename T>
unflatten_internal(const PyTreeSpec & spec,T && leaves_it) const183 py::object unflatten_internal(const PyTreeSpec& spec, T&& leaves_it) const {
184 switch (spec.kind()) {
185 case Kind::NamedTuple:
186 case Kind::Tuple: {
187 const size_t size = spec.size();
188 py::tuple tuple(size);
189 for (size_t i = 0; i < size; ++i) {
190 tuple[i] = unflatten_internal(spec[i], leaves_it);
191 }
192 return std::move(tuple);
193 }
194 case Kind::List: {
195 const size_t size = spec.size();
196 py::list list(size);
197 for (size_t i = 0; i < size; ++i) {
198 list[i] = unflatten_internal(spec[i], leaves_it);
199 }
200 return std::move(list);
201 }
202 case Kind::Custom: {
203 const auto& pytype_str = spec.handle->custom_type;
204 const auto* reg = PyTypeRegistry::get_by_str(pytype_str);
205 const size_t size = spec.size();
206 py::list list(size);
207 for (size_t i = 0; i < size; ++i) {
208 list[i] = unflatten_internal(spec[i], leaves_it);
209 }
210 py::object o = reg->unflatten(list, spec.handle->custom_type_context);
211 return o;
212 }
213 case Kind::Dict: {
214 const size_t size = spec.size();
215 py::dict dict;
216 for (size_t i = 0; i < size; ++i) {
217 auto& key = spec.key(i);
218 auto py_key = [&key]() -> py::handle {
219 switch (key.kind()) {
220 case Key::Kind::Int:
221 return py::cast(key.as_int()).release();
222 case Key::Kind::Str:
223 return py::cast(key.as_str()).release();
224 case Key::Kind::None:
225 pytree_assert(false);
226 }
227 pytree_assert(false);
228 return py::none();
229 }();
230 dict[py_key] = unflatten_internal(spec[i], leaves_it);
231 }
232 return std::move(dict);
233 }
234 case Kind::Leaf: {
235 py::object o =
236 py::reinterpret_borrow<py::object>(*std::forward<T>(leaves_it));
237 leaves_it++;
238 return o;
239 }
240 case Kind::None: {
241 return py::none();
242 }
243 }
244 pytree_assert(false);
245 }
246
247 public:
PyTree(PyTreeSpec spec)248 explicit PyTree(PyTreeSpec spec) : spec_(std::move(spec)) {}
249
spec() const250 const PyTreeSpec& spec() const {
251 return spec_;
252 }
253
py_from_str(std::string spec)254 static PyTree py_from_str(std::string spec) {
255 return PyTree(from_str<PyAux>(spec));
256 }
257
py_to_str() const258 StrTreeSpec py_to_str() const {
259 return to_str(spec_);
260 }
261
262 static std::pair<std::vector<py::object>, std::unique_ptr<PyTree>>
tree_flatten(py::handle x)263 tree_flatten(py::handle x) {
264 std::vector<py::object> leaves{};
265 PyTreeSpec spec{};
266 flatten_internal(x, leaves, spec);
267 refresh_leaves_num(spec);
268 return {std::move(leaves), std::make_unique<PyTree>(std::move(spec))};
269 }
270
tree_unflatten(py::iterable leaves,py::object o)271 static py::object tree_unflatten(py::iterable leaves, py::object o) {
272 return o.cast<PyTree*>()->tree_unflatten(leaves);
273 }
274
275 template <typename T>
tree_unflatten(T leaves) const276 py::object tree_unflatten(T leaves) const {
277 return unflatten_internal(spec_, leaves.begin());
278 }
279
operator ==(const PyTree & rhs)280 bool operator==(const PyTree& rhs) {
281 return spec_ == rhs.spec_;
282 }
283
leaves_num() const284 size_t leaves_num() const {
285 return refresh_leaves_num(spec_);
286 }
287 };
288
tree_flatten(py::handle x)289 inline std::pair<std::vector<py::object>, std::unique_ptr<PyTree>> tree_flatten(
290 py::handle x) {
291 return PyTree::tree_flatten(x);
292 }
293
tree_unflatten(py::iterable leaves,py::object o)294 inline py::object tree_unflatten(py::iterable leaves, py::object o) {
295 return PyTree::tree_unflatten(leaves, o);
296 }
297
tree_map(py::function & fn,py::handle x)298 static py::object tree_map(py::function& fn, py::handle x) {
299 auto p = tree_flatten(x);
300 const auto& leaves = p.first;
301 const auto& pytree = p.second;
302 std::vector<py::handle> vec;
303 for (const py::handle& h : leaves) {
304 vec.push_back(fn(h));
305 }
306 return pytree->tree_unflatten(vec);
307 }
308
py_from_str(std::string spec)309 static std::unique_ptr<PyTree> py_from_str(std::string spec) {
310 return std::make_unique<PyTree>(from_str<PyAux>(spec));
311 }
312
broadcast_to_and_flatten(py::object x,py::object py_tree_spec)313 static py::object broadcast_to_and_flatten(
314 py::object x,
315 py::object py_tree_spec) {
316 auto p = tree_flatten(x);
317 const auto& x_leaves = p.first;
318 const auto& x_spec = p.second->spec();
319
320 PyTree* tree_spec = py_tree_spec.cast<PyTree*>();
321
322 py::list ret;
323 struct StackItem {
324 const PyTreeSpec* tree_spec_node;
325 const PyTreeSpec* x_spec_node;
326 const size_t x_leaves_offset;
327 };
328 std::stack<StackItem> stack;
329 stack.push({&tree_spec->spec(), &x_spec, 0u});
330 while (!stack.empty()) {
331 const auto top = stack.top();
332 stack.pop();
333 if (top.x_spec_node->isLeaf()) {
334 for (size_t i = 0; i < top.tree_spec_node->leaves_num(); ++i) {
335 ret.append(x_leaves[top.x_leaves_offset]);
336 }
337 } else {
338 const auto kind = top.tree_spec_node->kind();
339 if (kind != top.x_spec_node->kind()) {
340 return py::none();
341 }
342 pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind());
343 const size_t child_num = top.tree_spec_node->size();
344 if (child_num != top.x_spec_node->size()) {
345 return py::none();
346 }
347 pytree_assert(child_num == top.x_spec_node->size());
348
349 size_t x_leaves_offset =
350 top.x_leaves_offset + top.x_spec_node->leaves_num();
351 auto fn_i = [&](size_t i) {
352 x_leaves_offset -= (*top.x_spec_node)[i].leaves_num();
353 stack.push(
354 {&(*top.tree_spec_node)[i],
355 &(*top.x_spec_node)[i],
356 x_leaves_offset});
357 };
358 if (Kind::Dict == kind) {
359 for (size_t i = child_num - 1; i < child_num; --i) {
360 if (top.tree_spec_node->key(i) != top.x_spec_node->key(i)) {
361 return py::none();
362 }
363 fn_i(i);
364 }
365 } else {
366 for (size_t i = child_num - 1; i < child_num; --i) {
367 fn_i(i);
368 }
369 }
370 }
371 }
372 return std::move(ret);
373 }
374
375 } // namespace
376
// Python module `pybindings`: pytree flatten/unflatten/map helpers, custom
// type registration, and the TreeSpec class.
PYBIND11_MODULE(pybindings, m) {
  m.def("tree_flatten", &tree_flatten, py::arg("tree"));
  m.def("tree_unflatten", &tree_unflatten, py::arg("leaves"), py::arg("tree"));
  m.def("tree_map", &tree_map);
  m.def("from_str", &py_from_str);
  m.def("broadcast_to_and_flatten", &broadcast_to_and_flatten);
  m.def("register_custom", &PyTypeRegistry::register_custom_type);

  py::class_<PyTree>(m, "TreeSpec")
      // PyTree::py_from_str is a static function. Binding it with def_static
      // keeps TreeSpec.from_str(s) working and also makes it callable on an
      // instance; with .def the instance would be passed as the string arg.
      .def_static("from_str", &PyTree::py_from_str)
      .def(
          "tree_unflatten",
          static_cast<py::object (PyTree::*)(py::iterable leaves) const>(
              &PyTree::tree_unflatten))
      .def("__repr__", &PyTree::py_to_str)
      // is_operator: comparing with a non-TreeSpec returns NotImplemented
      // (so Python falls back to False) instead of raising TypeError.
      .def("__eq__", &PyTree::operator==, py::is_operator())
      .def("to_str", &PyTree::py_to_str)
      .def("num_leaves", &PyTree::leaves_num);
}
396
397 } // namespace pytree
398 } // namespace extension
399 } // namespace executorch
400