/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include "executorch/extension/pytree/pytree.h" namespace py = pybind11; namespace executorch { namespace extension { namespace pytree { namespace { struct PyAux { py::object custom_type_context; }; using PyTreeSpec = TreeSpec; class PyTypeRegistry { public: struct PyTypeReg { explicit PyTypeReg(Kind k) : kind(k) {} Kind kind; // for custom types py::object type; // function type: object -> (children, spec_data) py::function flatten; // function type: (children, spec_data) -> object py::function unflatten; }; static const PyTypeReg* get_by_str(const std::string& pytype) { auto* registry = instance(); auto it = registry->regs_.find(pytype); return it == registry->regs_.end() ? nullptr : it->second.get(); } static const PyTypeReg* get_by_type(py::handle pytype) { return get_by_str(py::str(pytype)); } static void register_custom_type( py::object type, py::function flatten, py::function unflatten) { auto* registry = instance(); auto reg = std::make_unique(Kind::Custom); reg->type = type; reg->flatten = std::move(flatten); reg->unflatten = std::move(unflatten); std::string pytype_str = py::str(type); auto it = registry->regs_.emplace(pytype_str, std::move(reg)); if (!it.second) { assert(false); } } private: static PyTypeRegistry* instance() { static auto* registry_instance = []() -> PyTypeRegistry* { auto* registry = new PyTypeRegistry; auto add_pytype_reg = [&](const std::string& pytype, Kind kind) { registry->regs_.emplace(pytype, std::make_unique(kind)); }; add_pytype_reg("", Kind::Tuple); add_pytype_reg("", Kind::List); add_pytype_reg("", Kind::Dict); return registry; }(); return registry_instance; } std::unordered_map> regs_; }; class PyTree { PyTreeSpec spec_; static void flatten_internal( py::handle x, std::vector& leaves, PyTreeSpec& s) { const auto* reg = PyTypeRegistry::get_by_type(x.get_type()); const auto kind = [®, &x]() { if (reg) { return reg->kind; } if (py::isinstance(x) && py::hasattr(x, "_fields")) { return Kind::NamedTuple; } return Kind::Leaf; }(); switch (kind) { case Kind::List: { const size_t n = PyList_GET_SIZE(x.ptr()); s = PyTreeSpec(Kind::List, n); for (size_t i = 0; i < n; ++i) { flatten_internal(PyList_GET_ITEM(x.ptr(), i), leaves, s[i]); } break; } case Kind::Tuple: { const size_t n = PyTuple_GET_SIZE(x.ptr()); s = PyTreeSpec(Kind::Tuple, n); for (size_t i = 0; i < n; ++i) { flatten_internal(PyTuple_GET_ITEM(x.ptr(), i), leaves, s[i]); } break; } case Kind::NamedTuple: { py::tuple tuple = py::reinterpret_borrow(x); const size_t n = tuple.size(); s = PyTreeSpec(Kind::NamedTuple, n); size_t i = 0; for (py::handle entry : tuple) { flatten_internal(entry, leaves, s[i++]); } break; } case Kind::Dict: { py::dict dict = py::reinterpret_borrow(x); py::list keys = py::reinterpret_steal(PyDict_Keys(dict.ptr())); const auto n = PyList_GET_SIZE(keys.ptr()); s = PyTreeSpec(Kind::Dict, n); size_t i = 0; for (py::handle key : keys) { if (py::isinstance(key)) { s.key(i) = py::cast(key); } else if (py::isinstance(key)) { s.key(i) = py::cast(key); } else { pytree_assert(false); } flatten_internal(dict[key], leaves, s[i]); i++; } break; } case Kind::Custom: { py::tuple out = py::cast(reg->flatten(x)); if (out.size() != 2) { assert(false); } py::list children = py::cast(out[0]); const size_t n = children.size(); s = PyTreeSpec(Kind::Custom, n); s.handle->custom_type = py::str(x.get_type()); s.handle->custom_type_context = out[1]; size_t i = 0; for (py::handle pychild : children) { flatten_internal(pychild, leaves, s[i++]); } break; } case Kind::Leaf: { s = PyTreeSpec(Kind::Leaf); leaves.push_back(py::reinterpret_borrow(x)); break; } case Kind::None: pytree_assert(false); } } template py::object unflatten_internal(const PyTreeSpec& spec, T&& leaves_it) const { switch (spec.kind()) { case Kind::NamedTuple: case Kind::Tuple: { const size_t size = spec.size(); py::tuple tuple(size); for (size_t i = 0; i < size; ++i) { tuple[i] = unflatten_internal(spec[i], leaves_it); } return std::move(tuple); } case Kind::List: { const size_t size = spec.size(); py::list list(size); for (size_t i = 0; i < size; ++i) { list[i] = unflatten_internal(spec[i], leaves_it); } return std::move(list); } case Kind::Custom: { const auto& pytype_str = spec.handle->custom_type; const auto* reg = PyTypeRegistry::get_by_str(pytype_str); const size_t size = spec.size(); py::list list(size); for (size_t i = 0; i < size; ++i) { list[i] = unflatten_internal(spec[i], leaves_it); } py::object o = reg->unflatten(list, spec.handle->custom_type_context); return o; } case Kind::Dict: { const size_t size = spec.size(); py::dict dict; for (size_t i = 0; i < size; ++i) { auto& key = spec.key(i); auto py_key = [&key]() -> py::handle { switch (key.kind()) { case Key::Kind::Int: return py::cast(key.as_int()).release(); case Key::Kind::Str: return py::cast(key.as_str()).release(); case Key::Kind::None: pytree_assert(false); } pytree_assert(false); return py::none(); }(); dict[py_key] = unflatten_internal(spec[i], leaves_it); } return std::move(dict); } case Kind::Leaf: { py::object o = py::reinterpret_borrow(*std::forward(leaves_it)); leaves_it++; return o; } case Kind::None: { return py::none(); } } pytree_assert(false); } public: explicit PyTree(PyTreeSpec spec) : spec_(std::move(spec)) {} const PyTreeSpec& spec() const { return spec_; } static PyTree py_from_str(std::string spec) { return PyTree(from_str(spec)); } StrTreeSpec py_to_str() const { return to_str(spec_); } static std::pair, std::unique_ptr> tree_flatten(py::handle x) { std::vector leaves{}; PyTreeSpec spec{}; flatten_internal(x, leaves, spec); refresh_leaves_num(spec); return {std::move(leaves), std::make_unique(std::move(spec))}; } static py::object tree_unflatten(py::iterable leaves, py::object o) { return o.cast()->tree_unflatten(leaves); } template py::object tree_unflatten(T leaves) const { return unflatten_internal(spec_, leaves.begin()); } bool operator==(const PyTree& rhs) { return spec_ == rhs.spec_; } size_t leaves_num() const { return refresh_leaves_num(spec_); } }; inline std::pair, std::unique_ptr> tree_flatten( py::handle x) { return PyTree::tree_flatten(x); } inline py::object tree_unflatten(py::iterable leaves, py::object o) { return PyTree::tree_unflatten(leaves, o); } static py::object tree_map(py::function& fn, py::handle x) { auto p = tree_flatten(x); const auto& leaves = p.first; const auto& pytree = p.second; std::vector vec; for (const py::handle& h : leaves) { vec.push_back(fn(h)); } return pytree->tree_unflatten(vec); } static std::unique_ptr py_from_str(std::string spec) { return std::make_unique(from_str(spec)); } static py::object broadcast_to_and_flatten( py::object x, py::object py_tree_spec) { auto p = tree_flatten(x); const auto& x_leaves = p.first; const auto& x_spec = p.second->spec(); PyTree* tree_spec = py_tree_spec.cast(); py::list ret; struct StackItem { const PyTreeSpec* tree_spec_node; const PyTreeSpec* x_spec_node; const size_t x_leaves_offset; }; std::stack stack; stack.push({&tree_spec->spec(), &x_spec, 0u}); while (!stack.empty()) { const auto top = stack.top(); stack.pop(); if (top.x_spec_node->isLeaf()) { for (size_t i = 0; i < top.tree_spec_node->leaves_num(); ++i) { ret.append(x_leaves[top.x_leaves_offset]); } } else { const auto kind = top.tree_spec_node->kind(); if (kind != top.x_spec_node->kind()) { return py::none(); } pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind()); const size_t child_num = top.tree_spec_node->size(); if (child_num != top.x_spec_node->size()) { return py::none(); } pytree_assert(child_num == top.x_spec_node->size()); size_t x_leaves_offset = top.x_leaves_offset + top.x_spec_node->leaves_num(); auto fn_i = [&](size_t i) { x_leaves_offset -= (*top.x_spec_node)[i].leaves_num(); stack.push( {&(*top.tree_spec_node)[i], &(*top.x_spec_node)[i], x_leaves_offset}); }; if (Kind::Dict == kind) { for (size_t i = child_num - 1; i < child_num; --i) { if (top.tree_spec_node->key(i) != top.x_spec_node->key(i)) { return py::none(); } fn_i(i); } } else { for (size_t i = child_num - 1; i < child_num; --i) { fn_i(i); } } } } return std::move(ret); } } // namespace PYBIND11_MODULE(pybindings, m) { m.def("tree_flatten", &tree_flatten, py::arg("tree")); m.def("tree_unflatten", &tree_unflatten, py::arg("leaves"), py::arg("tree")); m.def("tree_map", &tree_map); m.def("from_str", &py_from_str); m.def("broadcast_to_and_flatten", &broadcast_to_and_flatten); m.def("register_custom", &PyTypeRegistry::register_custom_type); py::class_(m, "TreeSpec") .def("from_str", &PyTree::py_from_str) .def( "tree_unflatten", static_cast( &PyTree::tree_unflatten)) .def("__repr__", &PyTree::py_to_str) .def("__eq__", &PyTree::operator==) .def("to_str", &PyTree::py_to_str) .def("num_leaves", &PyTree::leaves_num); } } // namespace pytree } // namespace extension } // namespace executorch