xref: /aosp_15_r20/external/pytorch/torch/utils/_cxx_pytree.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2Contains utility functions for working with nested python data structures.
3
A *pytree* is a Python nested data structure. It is a tree in the sense that
5nodes are Python collections (e.g., list, tuple, dict) and the leaves are
6Python values. Furthermore, a pytree should not contain reference cycles.
7
8pytrees are useful for working with nested collections of Tensors. For example,
9one can use `tree_map` to map a function over all Tensors inside some nested
10collection of Tensors and `tree_leaves` to get a flat list of all Tensors
11inside some nested collection. pytrees are helpful for implementing nested
12collection support for PyTorch APIs.
13"""
14
15import functools
16import sys
17import types
18from typing import (
19    Any,
20    Callable,
21    Iterable,
22    List,
23    Optional,
24    overload,
25    Tuple,
26    Type,
27    TypeVar,
28    Union,
29)
30from typing_extensions import deprecated
31
32import optree
33from optree import PyTreeSpec  # direct import for type annotations
34
35import torch.utils._pytree as _pytree
36from torch.utils._pytree import KeyEntry
37
38
# Names exported by ``from torch.utils._cxx_pytree import *``, grouped by role.
__all__ = [
    # type aliases and treespec classes
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    "TreeSpec",
    "LeafSpec",
    # key helpers
    "keystr",
    "key_get",
    # node registration
    "register_pytree_node",
    # flattening / unflattening / structure
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_iter",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    # mapping
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    # reductions over leaves
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    # treespec serialization and printing
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
]
72
73
# Generic type variables; used by the overload helper aliases for the ``*_only`` functions.
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")


# A pytree, and the per-node context returned by a flatten function, may be any object.
PyTree = Any
Context = Any
# The C++ (optree) treespec class, exposed under the torch name.
TreeSpec = PyTreeSpec

# torch-style node callbacks:
#   flatten(node) -> (children, context)
#   unflatten(children, context) -> node
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
# optree calls its unflatten function with the arguments swapped: (context, children).
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
# Keyed variant of FlattenFunc: node -> ([(key_entry, child), ...], context).
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
# The key path of a leaf, as a tuple of key entries.
KeyPath = Tuple[KeyEntry, ...]

# Serialization hooks converting a node context to / from a JSON-dumpable value.
DumpableContext = Any  # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
91
92
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
    """Adapt a torch-style unflatten function to optree's calling convention.

    torch's ``UnflattenFunc`` takes ``(children, context)`` whereas optree invokes its
    unflatten function as ``(context, children)``. The returned wrapper reverses the
    positional arguments before forwarding them; keyword arguments pass through as-is.
    """

    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        swapped = args[::-1]
        return func(*swapped, **kwargs)

    return wrapped
99
100
def register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as pytree node.

    The type is registered both with the C++ pytree (optree, under the ``"torch"``
    namespace) and with the Python pytree in :mod:`torch.utils._pytree`, keeping the two
    registries in sync.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_fn (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
            recursively, and (2) some hashable auxiliary data (the context) to be stored in the
            treespec and to be passed to the ``unflatten_fn``.
        unflatten_fn (callable): A function taking two arguments, in this order: the unflattened
            children, and the auxiliary data that was returned by ``flatten_fn`` and stored in
            the treespec. The function should return an instance of ``cls``.
        serialized_type_name (str, optional): A keyword argument used to specify the fully
            qualified name used when serializing the tree spec.
        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable representation. This is
            used for json serialization, which is being used in :mod:`torch.export` right now.
        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
            how to convert the custom json dumpable representation of the context back to the
            original context. This is used for json deserialization, which is being used in
            :mod:`torch.export` right now.
        flatten_with_keys_fn (callable, optional): Key-path aware flatten function. Key paths are
            not supported by the C++ pytree yet, so this must be left as ``None``.

    Raises:
        NotImplementedError: If ``flatten_with_keys_fn`` is not ``None``. The check happens
            before anything is registered.

    Example::

        >>> # xdoctest: +SKIP
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None),
        ...     lambda children, _: set(children),
        ... )
    """
    if flatten_with_keys_fn is not None:
        raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")

    # C++ (optree) registry.
    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )

    # Mirror the registration in the Python pytree registry. ``_pytree`` is the
    # module-level alias for torch.utils._pytree, so no local import is needed.
    _pytree._private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
163    )
164
165
@deprecated(
    "`torch.utils._cxx_pytree._register_pytree_node` is deprecated. "
    "Please use `torch.utils._cxx_pytree.register_pytree_node` instead.",
    category=FutureWarning,
)
def _register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
    """Register a container-like type as pytree node for the C++ pytree only.

    .. deprecated::
        Use :func:`register_pytree_node` instead, which also registers the type with the
        Python pytree in :mod:`torch.utils._pytree`. Calling this function emits a
        :class:`FutureWarning`.

    The type is registered with optree under the fixed ``"torch"`` namespace; there is no
    ``namespace`` argument.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_fn (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
            recursively, and (2) some hashable auxiliary data (the context) to be stored in the
            treespec and to be passed to the ``unflatten_fn``.
        unflatten_fn (callable): A function taking two arguments, in this order: the unflattened
            children, and the auxiliary data that was returned by ``flatten_fn`` and stored in
            the treespec. The function should return an instance of ``cls``.
        serialized_type_name (str, optional): A keyword argument used to specify the fully
            qualified name used when serializing the tree spec.
        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable representation. This is
            used for json serialization, which is being used in :mod:`torch.export` right now.
        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
            how to convert the custom json dumpable representation of the context back to the
            original context. This is used for json deserialization, which is being used in
            :mod:`torch.export` right now.
    """

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
219    )
220
221
def _private_register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
    """Register ``cls`` with the C++ pytree (optree) under the ``"torch"`` namespace.

    Internal helper: end-users should call :func:`register_pytree_node`, which also
    registers the type with the Python pytree. ``unflatten_fn`` is wrapped so that optree's
    ``(context, children)`` call order reaches it as ``(children, context)``.

    The serialization keyword arguments are accepted for signature parity with the
    Python pytree and are not used here. PyStructSequence classes are not registered
    with optree by this function.
    """
    # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
    # PyStructSequence types
    if optree.is_structseq_class(cls):
        return

    optree.register_pytree_node(
        cls,
        flatten_fn,
        _reverse_args(unflatten_fn),
        namespace="torch",
    )
244
245
def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
    """Flatten a pytree into its leaves and a treespec.

    See also :func:`tree_unflatten`.

    Leaves are listed in left-to-right, depth-first traversal order, so the result is
    deterministic. ``None`` is treated as a leaf.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)
    ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
    >>> tree_flatten(1)
    ([1], PyTreeSpec(*, NoneIsLeaf))
    >>> tree_flatten(None)
    ([None], PyTreeSpec(*, NoneIsLeaf))

    Plain :class:`dict` and :class:`collections.defaultdict` nodes are traversed in **sorted**
    key order. Use :class:`collections.OrderedDict` to traverse keys in insertion order.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten(tree)
    ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An extra leaf predicate, called as ``is_leaf(node) -> bool``
            at each flattening step. If it returns :data:`True`, the whole subtree is treated as
            a leaf; otherwise the default pytree registry decides whether the node is a leaf.
            When omitted, only the default pytree registry is used.

    Returns:
        A pair ``(leaves, treespec)``: the list of leaf values and a treespec describing the
        structure of ``tree``.
    """
    leaves, treespec = optree.tree_flatten(
        tree,
        namespace="torch",
        none_is_leaf=True,
        is_leaf=is_leaf,
    )
    return leaves, treespec  # type: ignore[return-value]
292
293
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Rebuild a pytree by placing ``leaves`` into the structure described by ``treespec``.

    This is the inverse of :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> leaves, treespec = tree_flatten(tree)
    >>> tree == tree_unflatten(leaves, treespec)
    True

    Args:
        leaves (iterable): Leaf values in flattening order. Their number must equal the
            number of leaves of ``treespec``.
        treespec (TreeSpec): The structure to rebuild.

    Returns:
        A pytree shaped like ``treespec`` whose leaves are taken from ``leaves``.

    Raises:
        TypeError: If ``treespec`` is not a :class:`TreeSpec`.
    """
    if isinstance(treespec, TreeSpec):
        # optree takes the treespec first and the leaves second.
        return optree.tree_unflatten(treespec, leaves)  # type: ignore[arg-type]
    raise TypeError(
        f"tree_unflatten(values, spec): Expected `spec` to be instance of "
        f"TreeSpec but got item of type {type(treespec)}."
    )
319
320
def tree_iter(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
    """Return an iterator over the leaves of a pytree.

    See also :func:`tree_flatten`. Leaves are produced in the same order as
    :func:`tree_flatten` lists them, and ``None`` counts as a leaf.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> list(tree_iter(tree))
    [1, 2, 3, 4, None, 5]
    >>> list(tree_iter(1))
    [1]
    >>> list(tree_iter(None))
    [None]

    Args:
        tree (pytree): The pytree to iterate over.
        is_leaf (callable, optional): An extra leaf predicate, called as ``is_leaf(node) -> bool``
            at each flattening step. If it returns :data:`True`, the whole subtree is treated as
            a leaf; otherwise the default pytree registry decides whether the node is a leaf.
            When omitted, only the default pytree registry is used.

    Returns:
        An iterator over the leaf values.
    """
    leaf_iterator = optree.tree_iter(
        tree,
        namespace="torch",
        none_is_leaf=True,
        is_leaf=is_leaf,
    )
    return leaf_iterator
354
355
def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Return the leaves of a pytree as a list.

    See also :func:`tree_flatten`. This is the first element of
    ``tree_flatten(tree, is_leaf)``. ``None`` counts as a leaf.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_leaves(tree)
    [1, 2, 3, 4, None, 5]
    >>> tree_leaves(1)
    [1]
    >>> tree_leaves(None)
    [None]

    Args:
        tree (pytree): The pytree whose leaves are collected.
        is_leaf (callable, optional): An extra leaf predicate, called as ``is_leaf(node) -> bool``
            at each flattening step. If it returns :data:`True`, the whole subtree is treated as
            a leaf; otherwise the default pytree registry decides whether the node is a leaf.
            When omitted, only the default pytree registry is used.

    Returns:
        A list of leaf values.
    """
    leaves = optree.tree_leaves(
        tree,
        namespace="torch",
        none_is_leaf=True,
        is_leaf=is_leaf,
    )
    return leaves
389
390
def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Return the treespec describing the structure of a pytree.

    See also :func:`tree_flatten`. This is the second element of
    ``tree_flatten(tree, is_leaf)``.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_structure(tree)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    >>> tree_structure(1)
    PyTreeSpec(*, NoneIsLeaf)
    >>> tree_structure(None)
    PyTreeSpec(*, NoneIsLeaf)

    Args:
        tree (pytree): The pytree whose structure is computed.
        is_leaf (callable, optional): An extra leaf predicate, called as ``is_leaf(node) -> bool``
            at each flattening step. If it returns :data:`True`, the whole subtree is treated as
            a leaf; otherwise the default pytree registry decides whether the node is a leaf.
            When omitted, only the default pytree registry is used.

    Returns:
        A treespec object representing the structure of ``tree``.
    """
    treespec = optree.tree_structure(
        tree,
        namespace="torch",
        none_is_leaf=True,
        is_leaf=is_leaf,
    )
    return treespec  # type: ignore[return-value]
424
425
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Apply a (multi-input) function to the leaves of pytrees and build a new pytree.

    See also :func:`tree_map_`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': True}

    With several inputs, the output structure follows the first one, and every other
    input only needs to have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): Function of ``1 + len(rests)`` arguments, applied at each leaf
            position of ``tree``.
        tree (pytree): The pytree to map over; its leaves supply the first positional argument
            of ``func``.
        rests (tuple of pytree): Further pytrees, each with the same structure as ``tree`` or
            with ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate, called as ``is_leaf(node) -> bool``
            at each flattening step. If it returns :data:`True`, the whole subtree is treated as
            a leaf; otherwise the default pytree registry decides whether the node is a leaf.
            When omitted, only the default pytree registry is used.

    Returns:
        A new pytree shaped like ``tree`` whose leaf values are ``func(x, *xs)``, with ``x`` the
        leaf of ``tree`` and ``xs`` the values at the corresponding positions in ``rests``.
    """
    all_trees = (tree, *rests)
    return optree.tree_map(
        func,
        *all_trees,
        namespace="torch",
        none_is_leaf=True,
        is_leaf=is_leaf,
    )
473
474
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but call ``func`` on each leaf for its side effect and return
    the original tree.

    See also :func:`tree_map`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate, called as ``is_leaf(node) -> bool``
            at each flattening step. If it returns :data:`True`, the whole subtree is treated as
            a leaf; otherwise the default pytree registry decides whether the node is a leaf.
            When omitted, only the default pytree registry is used.

    Returns:
        The original ``tree`` object. ``func(x, *xs)`` is called at every leaf for its side
        effect only, and its return value is discarded. Here ``x`` is the leaf value in ``tree``
        and ``xs`` is the tuple of values at the corresponding nodes in ``rests``.
    """
    return optree.tree_map_(
        func,
        tree,
        *rests,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
511
512
# Class selectors for the overloads of map_only / tree_*_only: a pair or a triple of classes.
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]

# Any class selector usable with ``isinstance``: a class, a tuple of classes, and on
# Python 3.10+ also an ``X | Y`` union (``types.UnionType``).
if sys.version_info >= (3, 10):
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
else:
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

# Single-argument functions over the selected classes, returning ``R``.
Fn = Callable[[T], R]
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
FnAny = Callable[[Any], R]

# Decorator type returned by map_only: takes a function of type ``T`` and returns a
# wrapper that accepts any value.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
526
527
# The overloads below only sharpen static typing: they let type checkers infer the
# parameter type of the function passed to the decorator that map_only returns.
@overload
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: ...


@overload
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: ...


@overload
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: ...


# Needed so that callers further down, which forward an arbitrary ``TypeAny`` value to
# map_only, still type-check.
@overload
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: ...


@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: ...


def map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
) -> MapOnlyFn[FnAny[Any]]:
    """Build a decorator that applies a function only to selected values.

    Suppose you are writing a tree_map over tensors, leaving everything else
    unchanged. Without this helper you would have to write::

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With it, you only need to write::

        @map_only(Tensor)
        def go(t):
            return ...

    A value is selected with ``isinstance`` when the argument is a class, a tuple of
    classes, or (Python 3.10+) an ``X | Y`` union. Any other callable argument is used
    directly as the selection predicate. Unselected values are returned unchanged.
    You can also use ``tree_map_only`` directly.

    Raises:
        TypeError: If the argument is neither a class specification nor callable.
    """
    selector = __type_or_types_or_pred
    # Classes are callable too, so the isinstance-style forms must be tested first.
    use_isinstance = isinstance(selector, (type, tuple)) or (
        sys.version_info >= (3, 10) and isinstance(selector, types.UnionType)
    )

    if use_isinstance:

        def pred(x: Any) -> bool:
            return isinstance(x, selector)  # type: ignore[arg-type]

    elif callable(selector):
        pred = selector  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

    def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(func)
        def wrapped(x: T) -> Any:
            return func(x) if pred(x) else x

        return wrapped

    return wrapper
600
601
# Overloads: typed class selectors give ``func`` a precise parameter type.
@overload
def tree_map_only(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


def tree_map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but apply ``func`` only to selected leaves.

    A leaf is selected when it is an instance of the given class(es), or, when a
    predicate is given instead, when the predicate accepts it. Every other leaf is
    kept as-is. See :func:`map_only` for the accepted selector forms.
    """
    selective_func = map_only(__type_or_types_or_pred)(func)
    return tree_map(selective_func, tree, is_leaf=is_leaf)
649
650
# Overloads: typed class selectors give ``func`` a precise parameter type.
@overload
def tree_map_only_(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ...


def tree_map_only_(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map_`, but call ``func`` only on selected leaves.

    Leaves are selected as in :func:`map_only`. ``func`` is invoked for its side effect,
    and the original ``tree`` is returned.
    """
    selective_func = map_only(__type_or_types_or_pred)(func)
    return tree_map_(selective_func, tree, is_leaf=is_leaf)
698
699
def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for every leaf of ``tree``.

    Leaves are visited lazily in flattening order, and evaluation stops at the first leaf
    for which ``pred`` is falsy. A tree with no leaves yields ``True``.
    """
    return all(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
707
708
def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for at least one leaf of ``tree``.

    Leaves are visited lazily in flattening order, and evaluation stops at the first leaf
    for which ``pred`` is truthy. A tree with no leaves yields ``False``.
    """
    return any(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
716
717
# Overloads: typed class selectors give ``pred`` a precise parameter type.
@overload
def tree_all_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_all_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_all_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


def tree_all_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for every leaf that is an instance of ``__type_or_types``.

    Leaves of other types are ignored and ``pred`` is never called on them. Iteration stops
    at the first matching leaf for which ``pred`` is falsy. If no leaf matches, the result
    is ``True``.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, __type_or_types) and not pred(leaf):
            return False
    return True
756
757
# Overloads: typed class selectors give ``pred`` a precise parameter type.
@overload
def tree_any_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_any_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


@overload
def tree_any_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool: ...


def tree_any_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for some leaf that is an instance of ``__type_or_types``.

    Leaves of other types are ignored and ``pred`` is never called on them. Iteration stops
    at the first matching leaf for which ``pred`` is truthy. If no leaf matches, the result
    is ``False``.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, __type_or_types) and pred(leaf):
            return True
    return False
796
797
def broadcast_prefix(
    prefix_tree: PyTree,
    full_tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.

    If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
    constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**.

    This function returns a list of leaves with the same size as ``full_tree``. The leaves are
    replicated from ``prefix_tree``. The number of replicas is determined by the corresponding
    subtree in ``full_tree``.

    >>> broadcast_prefix(1, [1, 2, 3])
    [1, 1, 1]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3])
    [1, 2, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
    >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
    [1, 2, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
    [1, 2, 3, 3, 3, 3]

    Args:
        prefix_tree (pytree): A pytree whose structure is a prefix of the structure of
            ``full_tree``.
        full_tree (pytree): A pytree that has ``prefix_tree`` as a prefix, i.e. one that can be
            obtained by replacing leaves of ``prefix_tree`` with subtrees.
        is_leaf (callable, optional): An extra leaf predicate, called as ``is_leaf(node) -> bool``
            at each flattening step. If it returns :data:`True`, the whole subtree is treated as
            a leaf; otherwise the default pytree registry decides whether the node is a leaf.
            When omitted, only the default pytree registry is used.

    Returns:
        A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.

    Raises:
        ValueError: If ``prefix_tree`` is not a prefix of ``full_tree`` (see the arity-mismatch
            example above).
    """
    return optree.broadcast_prefix(
        prefix_tree,
        full_tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
844
845
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
    """Broadcast ``tree`` to the structure of ``treespec`` and return the flattened values.

    Returns ``None`` when ``tree`` cannot be broadcast to ``treespec``. For example,
    ``tree=0`` with the treespec of a two-element list (``[*, *]``) yields ``[0, 0]``.

    This is part of the vmap implementation: in ``vmap(fn, in_dims)(*inputs)``, ``in_dims``
    must be broadcastable to the tree structure of ``inputs``, and this function is used to
    check that.
    """
    assert isinstance(treespec, TreeSpec)
    # Materialize a placeholder tree with the target structure so that optree's
    # prefix broadcasting can be reused.
    template = tree_unflatten([0] * treespec.num_leaves, treespec)
    try:
        broadcasted = broadcast_prefix(tree, template, is_leaf=is_leaf)
    except ValueError:
        return None
    return broadcasted
865
866
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
    """Serialize a treespec to a JSON string.

    The C++ treespec is first converted into an equivalent Python-pytree treespec. This is
    done by unflattening a placeholder tree (every leaf ``0``) and re-reading its structure
    with the Python pytree. That treespec is then serialized by the Python pytree's
    ``treespec_dumps`` with the given ``protocol``.

    Raises:
        TypeError: If ``treespec`` is not a :class:`TreeSpec`.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"treespec_dumps(spec): Expected `spec` to be instance of "
            f"TreeSpec but got item of type {type(treespec)}."
        )

    from ._pytree import tree_structure as _tree_structure, treespec_dumps as _treespec_dumps

    placeholder_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
    python_treespec = _tree_structure(placeholder_tree)
    return _treespec_dumps(python_treespec, protocol=protocol)
881
882
883def treespec_loads(serialized: str) -> TreeSpec:
884    """Deserialize a treespec from a JSON string."""
885    from ._pytree import (
886        tree_unflatten as _tree_unflatten,
887        treespec_loads as _treespec_loads,
888    )
889
890    orig_treespec = _treespec_loads(serialized)
891    dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
892    treespec = tree_structure(dummy_tree)
893    return treespec
894
895
896class _DummyLeaf:
897    def __repr__(self) -> str:
898        return "*"
899
900
901def treespec_pprint(treespec: TreeSpec) -> str:
902    dummy_tree = tree_unflatten(
903        [_DummyLeaf() for _ in range(treespec.num_leaves)],
904        treespec,
905    )
906    return repr(dummy_tree)
907
908
909class LeafSpecMeta(type(TreeSpec)):  # type: ignore[misc]
910    def __instancecheck__(self, instance: object) -> bool:
911        return isinstance(instance, TreeSpec) and instance.is_leaf()
912
913
914class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
915    def __new__(cls) -> "LeafSpec":
916        return optree.treespec_leaf(none_is_leaf=True)  # type: ignore[return-value]
917
918
919def tree_flatten_with_path(
920    tree: PyTree,
921    is_leaf: Optional[Callable[[PyTree], bool]] = None,
922) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
923    """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
924
925    Args:
926        tree: a pytree to flatten. If it contains a custom type, that type must be
927            registered with an appropriate `tree_flatten_with_path_fn` when registered
928            with :func:`register_pytree_node`.
929        is_leaf: An extra leaf predicate function that will be called at each
930            flattening step. The function should have a single argument with signature
931            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
932            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
933            leaf or not. If the function is not specified, the default pytree registry will be used.
934    Returns:
935        A tuple where the first element is a list of (key path, leaf) pairs, and the
936        second element is a :class:`TreeSpec` representing the structure of the flattened
937        tree.
938    """
939    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
940
941
942def tree_leaves_with_path(
943    tree: PyTree,
944    is_leaf: Optional[Callable[[PyTree], bool]] = None,
945) -> List[Tuple[KeyPath, Any]]:
946    """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
947
948    Args:
949        tree: a pytree. If it contains a custom type, that type must be
950            registered with an appropriate `tree_flatten_with_path_fn` when registered
951            with :func:`register_pytree_node`.
952        is_leaf: An extra leaf predicate function that will be called at each
953            flattening step. The function should have a single argument with signature
954            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
955            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
956            leaf or not. If the function is not specified, the default pytree registry will be used.
957    Returns:
958        A list of (key path, leaf) pairs.
959    """
960    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
961
962
963def tree_map_with_path(
964    func: Callable[..., Any],
965    tree: PyTree,
966    *rests: PyTree,
967    is_leaf: Optional[Callable[[PyTree], bool]] = None,
968) -> PyTree:
969    """Like :func:`tree_map`, but the provided callable takes an additional key path argument.
970
971    Args:
972        func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
973            corresponding leaves of the pytrees. The first positional argument
974            to ``func`` is the key path of the leaf in question. The second
975            positional argument is the value of the leaf.
976        tree: A pytree to be mapped over, with each leaf providing the first positional
977            argument to function ``func``.
978        rests: A tuple of pytrees, each of which has the same structure as
979            ``tree`` or has ``tree`` as a prefix.
980        is_leaf: An extra leaf predicate function that will be called at each
981            flattening step. The function should have a single argument with signature
982            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
983            as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
984            leaf or not. If the function is not specified, the default pytree registry will be used.
985
986    Returns
987        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
988        ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
989        corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
990        ``xs`` is the tuple of values at corresponding nodes in ``rests``.
991    """
992    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
993
994
995def keystr(kp: KeyPath) -> str:
996    """Given a key path, return a pretty-printed representation."""
997    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
998
999
1000def key_get(obj: Any, kp: KeyPath) -> Any:
1001    """Given an object and a key path, return the value at the key path."""
1002    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
1003
1004
1005_pytree._cxx_pytree_imported = True
1006for args, kwargs in _pytree._cxx_pytree_pending_imports:
1007    _private_register_pytree_node(*args, **kwargs)
1008