xref: /aosp_15_r20/external/executorch/extension/pytree/pybindings.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <pybind11/pybind11.h>
10 #include <pybind11/stl.h>
11 #include <memory>
12 #include <stack>
13 
14 #include "executorch/extension/pytree/pytree.h"
15 
16 namespace py = pybind11;
17 
18 namespace executorch {
19 namespace extension {
20 namespace pytree {
21 
22 namespace {
23 
24 struct PyAux {
25   py::object custom_type_context;
26 };
27 using PyTreeSpec = TreeSpec<PyAux>;
28 
29 class PyTypeRegistry {
30  public:
31   struct PyTypeReg {
PyTypeRegexecutorch::extension::pytree::__anon59fa55ac0111::PyTypeRegistry::PyTypeReg32     explicit PyTypeReg(Kind k) : kind(k) {}
33 
34     Kind kind;
35 
36     // for custom types
37     py::object type;
38     // function type: object -> (children, spec_data)
39     py::function flatten;
40     // function type: (children, spec_data) -> object
41     py::function unflatten;
42   };
43 
get_by_str(const std::string & pytype)44   static const PyTypeReg* get_by_str(const std::string& pytype) {
45     auto* registry = instance();
46     auto it = registry->regs_.find(pytype);
47     return it == registry->regs_.end() ? nullptr : it->second.get();
48   }
49 
get_by_type(py::handle pytype)50   static const PyTypeReg* get_by_type(py::handle pytype) {
51     return get_by_str(py::str(pytype));
52   }
53 
register_custom_type(py::object type,py::function flatten,py::function unflatten)54   static void register_custom_type(
55       py::object type,
56       py::function flatten,
57       py::function unflatten) {
58     auto* registry = instance();
59     auto reg = std::make_unique<PyTypeReg>(Kind::Custom);
60     reg->type = type;
61     reg->flatten = std::move(flatten);
62     reg->unflatten = std::move(unflatten);
63     std::string pytype_str = py::str(type);
64     auto it = registry->regs_.emplace(pytype_str, std::move(reg));
65     if (!it.second) {
66       assert(false);
67     }
68   }
69 
70  private:
instance()71   static PyTypeRegistry* instance() {
72     static auto* registry_instance = []() -> PyTypeRegistry* {
73       auto* registry = new PyTypeRegistry;
74 
75       auto add_pytype_reg = [&](const std::string& pytype, Kind kind) {
76         registry->regs_.emplace(pytype, std::make_unique<PyTypeReg>(kind));
77       };
78 
79       add_pytype_reg("<class 'tuple'>", Kind::Tuple);
80       add_pytype_reg("<class 'list'>", Kind::List);
81       add_pytype_reg("<class 'dict'>", Kind::Dict);
82 
83       return registry;
84     }();
85 
86     return registry_instance;
87   }
88   std::unordered_map<std::string, std::unique_ptr<PyTypeReg>> regs_;
89 };
90 
91 class PyTree {
92   PyTreeSpec spec_;
93 
flatten_internal(py::handle x,std::vector<py::object> & leaves,PyTreeSpec & s)94   static void flatten_internal(
95       py::handle x,
96       std::vector<py::object>& leaves,
97       PyTreeSpec& s) {
98     const auto* reg = PyTypeRegistry::get_by_type(x.get_type());
99     const auto kind = [&reg, &x]() {
100       if (reg) {
101         return reg->kind;
102       }
103       if (py::isinstance<py::tuple>(x) && py::hasattr(x, "_fields")) {
104         return Kind::NamedTuple;
105       }
106       return Kind::Leaf;
107     }();
108     switch (kind) {
109       case Kind::List: {
110         const size_t n = PyList_GET_SIZE(x.ptr());
111         s = PyTreeSpec(Kind::List, n);
112         for (size_t i = 0; i < n; ++i) {
113           flatten_internal(PyList_GET_ITEM(x.ptr(), i), leaves, s[i]);
114         }
115         break;
116       }
117       case Kind::Tuple: {
118         const size_t n = PyTuple_GET_SIZE(x.ptr());
119         s = PyTreeSpec(Kind::Tuple, n);
120         for (size_t i = 0; i < n; ++i) {
121           flatten_internal(PyTuple_GET_ITEM(x.ptr(), i), leaves, s[i]);
122         }
123         break;
124       }
125       case Kind::NamedTuple: {
126         py::tuple tuple = py::reinterpret_borrow<py::tuple>(x);
127         const size_t n = tuple.size();
128         s = PyTreeSpec(Kind::NamedTuple, n);
129         size_t i = 0;
130         for (py::handle entry : tuple) {
131           flatten_internal(entry, leaves, s[i++]);
132         }
133         break;
134       }
135       case Kind::Dict: {
136         py::dict dict = py::reinterpret_borrow<py::dict>(x);
137         py::list keys =
138             py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
139         const auto n = PyList_GET_SIZE(keys.ptr());
140         s = PyTreeSpec(Kind::Dict, n);
141         size_t i = 0;
142         for (py::handle key : keys) {
143           if (py::isinstance<py::str>(key)) {
144             s.key(i) = py::cast<std::string>(key);
145           } else if (py::isinstance<py::int_>(key)) {
146             s.key(i) = py::cast<int32_t>(key);
147           } else {
148             pytree_assert(false);
149           }
150 
151           flatten_internal(dict[key], leaves, s[i]);
152           i++;
153         }
154         break;
155       }
156       case Kind::Custom: {
157         py::tuple out = py::cast<py::tuple>(reg->flatten(x));
158         if (out.size() != 2) {
159           assert(false);
160         }
161         py::list children = py::cast<py::list>(out[0]);
162         const size_t n = children.size();
163         s = PyTreeSpec(Kind::Custom, n);
164         s.handle->custom_type = py::str(x.get_type());
165         s.handle->custom_type_context = out[1];
166         size_t i = 0;
167         for (py::handle pychild : children) {
168           flatten_internal(pychild, leaves, s[i++]);
169         }
170         break;
171       }
172       case Kind::Leaf: {
173         s = PyTreeSpec(Kind::Leaf);
174         leaves.push_back(py::reinterpret_borrow<py::object>(x));
175         break;
176       }
177       case Kind::None:
178         pytree_assert(false);
179     }
180   }
181 
182   template <typename T>
unflatten_internal(const PyTreeSpec & spec,T && leaves_it) const183   py::object unflatten_internal(const PyTreeSpec& spec, T&& leaves_it) const {
184     switch (spec.kind()) {
185       case Kind::NamedTuple:
186       case Kind::Tuple: {
187         const size_t size = spec.size();
188         py::tuple tuple(size);
189         for (size_t i = 0; i < size; ++i) {
190           tuple[i] = unflatten_internal(spec[i], leaves_it);
191         }
192         return std::move(tuple);
193       }
194       case Kind::List: {
195         const size_t size = spec.size();
196         py::list list(size);
197         for (size_t i = 0; i < size; ++i) {
198           list[i] = unflatten_internal(spec[i], leaves_it);
199         }
200         return std::move(list);
201       }
202       case Kind::Custom: {
203         const auto& pytype_str = spec.handle->custom_type;
204         const auto* reg = PyTypeRegistry::get_by_str(pytype_str);
205         const size_t size = spec.size();
206         py::list list(size);
207         for (size_t i = 0; i < size; ++i) {
208           list[i] = unflatten_internal(spec[i], leaves_it);
209         }
210         py::object o = reg->unflatten(list, spec.handle->custom_type_context);
211         return o;
212       }
213       case Kind::Dict: {
214         const size_t size = spec.size();
215         py::dict dict;
216         for (size_t i = 0; i < size; ++i) {
217           auto& key = spec.key(i);
218           auto py_key = [&key]() -> py::handle {
219             switch (key.kind()) {
220               case Key::Kind::Int:
221                 return py::cast(key.as_int()).release();
222               case Key::Kind::Str:
223                 return py::cast(key.as_str()).release();
224               case Key::Kind::None:
225                 pytree_assert(false);
226             }
227             pytree_assert(false);
228             return py::none();
229           }();
230           dict[py_key] = unflatten_internal(spec[i], leaves_it);
231         }
232         return std::move(dict);
233       }
234       case Kind::Leaf: {
235         py::object o =
236             py::reinterpret_borrow<py::object>(*std::forward<T>(leaves_it));
237         leaves_it++;
238         return o;
239       }
240       case Kind::None: {
241         return py::none();
242       }
243     }
244     pytree_assert(false);
245   }
246 
247  public:
PyTree(PyTreeSpec spec)248   explicit PyTree(PyTreeSpec spec) : spec_(std::move(spec)) {}
249 
spec() const250   const PyTreeSpec& spec() const {
251     return spec_;
252   }
253 
py_from_str(std::string spec)254   static PyTree py_from_str(std::string spec) {
255     return PyTree(from_str<PyAux>(spec));
256   }
257 
py_to_str() const258   StrTreeSpec py_to_str() const {
259     return to_str(spec_);
260   }
261 
262   static std::pair<std::vector<py::object>, std::unique_ptr<PyTree>>
tree_flatten(py::handle x)263   tree_flatten(py::handle x) {
264     std::vector<py::object> leaves{};
265     PyTreeSpec spec{};
266     flatten_internal(x, leaves, spec);
267     refresh_leaves_num(spec);
268     return {std::move(leaves), std::make_unique<PyTree>(std::move(spec))};
269   }
270 
tree_unflatten(py::iterable leaves,py::object o)271   static py::object tree_unflatten(py::iterable leaves, py::object o) {
272     return o.cast<PyTree*>()->tree_unflatten(leaves);
273   }
274 
275   template <typename T>
tree_unflatten(T leaves) const276   py::object tree_unflatten(T leaves) const {
277     return unflatten_internal(spec_, leaves.begin());
278   }
279 
operator ==(const PyTree & rhs)280   bool operator==(const PyTree& rhs) {
281     return spec_ == rhs.spec_;
282   }
283 
leaves_num() const284   size_t leaves_num() const {
285     return refresh_leaves_num(spec_);
286   }
287 };
288 
tree_flatten(py::handle x)289 inline std::pair<std::vector<py::object>, std::unique_ptr<PyTree>> tree_flatten(
290     py::handle x) {
291   return PyTree::tree_flatten(x);
292 }
293 
tree_unflatten(py::iterable leaves,py::object o)294 inline py::object tree_unflatten(py::iterable leaves, py::object o) {
295   return PyTree::tree_unflatten(leaves, o);
296 }
297 
tree_map(py::function & fn,py::handle x)298 static py::object tree_map(py::function& fn, py::handle x) {
299   auto p = tree_flatten(x);
300   const auto& leaves = p.first;
301   const auto& pytree = p.second;
302   std::vector<py::handle> vec;
303   for (const py::handle& h : leaves) {
304     vec.push_back(fn(h));
305   }
306   return pytree->tree_unflatten(vec);
307 }
308 
py_from_str(std::string spec)309 static std::unique_ptr<PyTree> py_from_str(std::string spec) {
310   return std::make_unique<PyTree>(from_str<PyAux>(spec));
311 }
312 
broadcast_to_and_flatten(py::object x,py::object py_tree_spec)313 static py::object broadcast_to_and_flatten(
314     py::object x,
315     py::object py_tree_spec) {
316   auto p = tree_flatten(x);
317   const auto& x_leaves = p.first;
318   const auto& x_spec = p.second->spec();
319 
320   PyTree* tree_spec = py_tree_spec.cast<PyTree*>();
321 
322   py::list ret;
323   struct StackItem {
324     const PyTreeSpec* tree_spec_node;
325     const PyTreeSpec* x_spec_node;
326     const size_t x_leaves_offset;
327   };
328   std::stack<StackItem> stack;
329   stack.push({&tree_spec->spec(), &x_spec, 0u});
330   while (!stack.empty()) {
331     const auto top = stack.top();
332     stack.pop();
333     if (top.x_spec_node->isLeaf()) {
334       for (size_t i = 0; i < top.tree_spec_node->leaves_num(); ++i) {
335         ret.append(x_leaves[top.x_leaves_offset]);
336       }
337     } else {
338       const auto kind = top.tree_spec_node->kind();
339       if (kind != top.x_spec_node->kind()) {
340         return py::none();
341       }
342       pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind());
343       const size_t child_num = top.tree_spec_node->size();
344       if (child_num != top.x_spec_node->size()) {
345         return py::none();
346       }
347       pytree_assert(child_num == top.x_spec_node->size());
348 
349       size_t x_leaves_offset =
350           top.x_leaves_offset + top.x_spec_node->leaves_num();
351       auto fn_i = [&](size_t i) {
352         x_leaves_offset -= (*top.x_spec_node)[i].leaves_num();
353         stack.push(
354             {&(*top.tree_spec_node)[i],
355              &(*top.x_spec_node)[i],
356              x_leaves_offset});
357       };
358       if (Kind::Dict == kind) {
359         for (size_t i = child_num - 1; i < child_num; --i) {
360           if (top.tree_spec_node->key(i) != top.x_spec_node->key(i)) {
361             return py::none();
362           }
363           fn_i(i);
364         }
365       } else {
366         for (size_t i = child_num - 1; i < child_num; --i) {
367           fn_i(i);
368         }
369       }
370     }
371   }
372   return std::move(ret);
373 }
374 
375 } // namespace
376 
// Python module definition: exposes the free pytree functions and the
// TreeSpec class.
PYBIND11_MODULE(pybindings, m) {
  m.def("tree_flatten", &tree_flatten, py::arg("tree"));
  m.def("tree_unflatten", &tree_unflatten, py::arg("leaves"), py::arg("tree"));
  m.def("tree_map", &tree_map);
  m.def("from_str", &py_from_str);
  m.def("broadcast_to_and_flatten", &broadcast_to_and_flatten);
  m.def("register_custom", &PyTypeRegistry::register_custom_type);

  py::class_<PyTree>(m, "TreeSpec")
      // py_from_str is a static member function, so bind it as a static
      // method; bound with def() it received `self` as the spec string when
      // called on an instance.
      .def_static("from_str", &PyTree::py_from_str)
      .def(
          "tree_unflatten",
          static_cast<py::object (PyTree::*)(py::iterable leaves) const>(
              &PyTree::tree_unflatten))
      .def("__repr__", &PyTree::py_to_str)
      // is_operator makes comparison against a non-TreeSpec return
      // NotImplemented instead of raising TypeError.
      .def("__eq__", &PyTree::operator==, py::is_operator())
      .def("to_str", &PyTree::py_to_str)
      .def("num_leaves", &PyTree::leaves_num);
}
396 
397 } // namespace pytree
398 } // namespace extension
399 } // namespace executorch
400