# xref: /aosp_15_r20/external/executorch/extension/pytree/__init__.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa: F401

import logging

# Package-level logger, named after this module. Its threshold is pinned to
# WARNING, so any INFO-level records logged through it are dropped unless the
# level is changed later.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)


try:
    # Internally we link the respective C++ library functions, but for the OSS
    # pip build we just use the Python library for now. The Python library is
    # not exactly the same, so it will not work for the runtime, but that is
    # acceptable for now since in most cases the runtime does not need it.

    # pyre-fixme[21]: Could not find module `executorch.extension.pytree.pybindings`.
    # @manual=//executorch/extension/pytree:pybindings
    from executorch.extension.pytree.pybindings import (
        broadcast_to_and_flatten as broadcast_to_and_flatten,
        from_str as from_str,
        register_custom as register_custom,
        tree_flatten as tree_flatten,
        tree_map as tree_map,
        tree_unflatten as tree_unflatten,
        TreeSpec as TreeSpec,
    )
except ImportError:
    # Only a missing or unloadable bindings module should trigger the fallback.
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit and mask
    # unrelated bugs.
    logger.info(
        "Unable to import executorch.extension.pytree.pybindings, "
        "using native torch pytree instead."
    )

    from torch.utils._pytree import (
        _broadcast_to_and_flatten,
        _register_pytree_node,
        tree_flatten,
        tree_map,
        tree_unflatten,
        TreeSpec,
        treespec_dumps,
        treespec_loads,
    )

    # Re-export the torch implementations under the names the pybindings
    # branch would have provided.
    broadcast_to_and_flatten = _broadcast_to_and_flatten
    from_str = treespec_loads
    register_custom = _register_pytree_node
    # Monkeypatch torch's TreeSpec class (a process-wide side effect) so specs
    # expose `to_str`, pairing with `from_str` = treespec_loads above.
    # NOTE(review): presumably mirrors the pybindings TreeSpec API — confirm.
    TreeSpec.to_str = treespec_dumps  # pyre-ignore
