"""
Contains utility functions for working with nested python data structures.

A *pytree* is a Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""

import functools
import sys
import types
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Optional,
    overload,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from typing_extensions import deprecated

import optree
from optree import PyTreeSpec  # direct import for type annotations

import torch.utils._pytree as _pytree
from torch.utils._pytree import KeyEntry


__all__ = [
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    "TreeSpec",
    "LeafSpec",
    "keystr",
    "key_get",
    "register_pytree_node",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_iter",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
]


T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")


Context = Any
PyTree = Any
TreeSpec = PyTreeSpec
# Torch-style flatten: node -> (children, context).
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
# Torch-style unflatten: (children, context) -> node.
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
# optree-style unflatten: (context, children) -> node, i.e. the reverse argument order.
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
DumpableContext = Any  # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
KeyPath = Tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]


def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
    """Wrap ``func`` so that it is called with its positional arguments in reverse order.

    Used to adapt a torch-style unflatten function ``(children, context)`` to the
    ``(context, children)`` order described by :data:`OpTreeUnflattenFunc`. Keyword arguments
    are passed through unchanged.
    """

    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return func(*reversed(args), **kwargs)

    return wrapped


def register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as pytree node.

    The type is registered with both the C++ pytree (through
    :func:`_private_register_pytree_node` in this module) and the Python pytree
    (``torch.utils._pytree._private_register_pytree_node``).

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_fn (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
            recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
            passed to the ``unflatten_fn``.
        unflatten_fn (callable): A function taking two arguments, in this order: the unflattened
            children, and the auxiliary data that was returned by ``flatten_fn`` and stored in
            the treespec. The function should return an instance of ``cls``.
        serialized_type_name (str, optional): A keyword argument used to specify the fully
            qualified name used when serializing the tree spec.
        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable representation. This is
            used for json serialization, which is being used in :mod:`torch.export` right now.
        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
            how to convert the custom json dumpable representation of the context back to the
            original context. This is used for json deserialization, which is being used in
            :mod:`torch.export` right now.
        flatten_with_keys_fn (callable, optional): Not supported by this module; must be left as
            :data:`None`.

    Raises:
        NotImplementedError: If ``flatten_with_keys_fn`` is not :data:`None`.

    Example::

        >>> # xdoctest: +SKIP
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None),
        ...     lambda children, _: set(children),
        ... )
    """
    if flatten_with_keys_fn is not None:
        raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")

    # Register with the C++ (optree) pytree first ...
    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )

    # ... then mirror the registration into the Python pytree implementation.
    from . import _pytree as python

    python._private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )


@deprecated(
    "`torch.utils._cxx_pytree._register_pytree_node` is deprecated. "
    "Please use `torch.utils._cxx_pytree.register_pytree_node` instead.",
    category=FutureWarning,
)
def _register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
    """Register a container-like type as pytree node for the C++ pytree only.

    Deprecated in favor of :func:`register_pytree_node`. Unlike that function, this one does not
    also register ``cls`` with the Python pytree implementation.

    This function takes no ``namespace`` argument: the registration is delegated to
    :func:`_private_register_pytree_node`, which always registers the type in the ``"torch"``
    optree namespace.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_fn (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
            recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
            passed to the ``unflatten_fn``.
        unflatten_fn (callable): A function taking two arguments, in this order: the unflattened
            children, and the auxiliary data that was returned by ``flatten_fn`` and stored in
            the treespec. The function should return an instance of ``cls``.
        serialized_type_name (str, optional): A keyword argument used to specify the fully
            qualified name used when serializing the tree spec. Passed on to
            :func:`_private_register_pytree_node`, which currently does not forward it to optree.
        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable representation.
            Passed on to :func:`_private_register_pytree_node`, which currently does not forward
            it to optree.
        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
            how to convert the custom json dumpable representation of the context back to the
            original context. Passed on to :func:`_private_register_pytree_node`, which currently
            does not forward it to optree.
    """

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )


def _private_register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
    """This is an internal function that is used to register a pytree node type
    for the C++ pytree only. End-users should use :func:`register_pytree_node`
    instead.

    The type is registered with optree in the ``"torch"`` namespace, with ``unflatten_fn``
    wrapped by :func:`_reverse_args`. Classes for which ``optree.is_structseq_class`` returns
    true are skipped. ``serialized_type_name``, ``to_dumpable_context`` and
    ``from_dumpable_context`` are accepted but not used here.
    """
    # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
    # PyStructSequence types
    if not optree.is_structseq_class(cls):
        optree.register_pytree_node(
            cls,
            flatten_fn,
            _reverse_args(unflatten_fn),
            namespace="torch",
        )


def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
    """Flatten a pytree.

    See also :func:`tree_unflatten`.

    The flattening order (i.e., the order of elements in the output list) is deterministic,
    corresponding to a left-to-right depth-first tree traversal.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)
    ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
    >>> tree_flatten(1)
    ([1], PyTreeSpec(*, NoneIsLeaf))
    >>> tree_flatten(None)
    ([None], PyTreeSpec(*, NoneIsLeaf))

    For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
    dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
    if you want to keep the keys in the insertion order.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten(tree)
    ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))

    Args:
        tree (pytree): A pytree to flatten.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the
        second element is a treespec representing the structure of the pytree.
    """
    return optree.tree_flatten(  # type: ignore[return-value]
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )


def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Reconstruct a pytree from the treespec and the leaves.

    The inverse of :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> leaves, treespec = tree_flatten(tree)
    >>> tree == tree_unflatten(leaves, treespec)
    True

    Args:
        leaves (iterable): The list of leaves to use for reconstruction. The list must match the
            number of leaves of the treespec.
        treespec (TreeSpec): The treespec to reconstruct.

    Returns:
        The reconstructed pytree, containing the ``leaves`` placed in the structure described by
        ``treespec``.

    Raises:
        TypeError: If ``treespec`` is not an instance of :class:`TreeSpec`.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"tree_unflatten(values, spec): Expected `spec` to be instance of "
            f"TreeSpec but got item of type {type(treespec)}."
        )
    # Note the argument order: optree takes the treespec first, then the leaves.
    return optree.tree_unflatten(treespec, leaves)  # type: ignore[arg-type]


def tree_iter(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
    """Get an iterator over the leaves of a pytree.

    See also :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> list(tree_iter(tree))
    [1, 2, 3, 4, None, 5]
    >>> list(tree_iter(1))
    [1]
    >>> list(tree_iter(None))
    [None]

    Args:
        tree (pytree): A pytree to flatten.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        An iterator over the leaf values.
    """
    return optree.tree_iter(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )


def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Get the leaves of a pytree.

    See also :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_leaves(tree)
    [1, 2, 3, 4, None, 5]
    >>> tree_leaves(1)
    [1]
    >>> tree_leaves(None)
    [None]

    Args:
        tree (pytree): A pytree to flatten.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A list of leaf values.
    """
    return optree.tree_leaves(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )


def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Get the treespec for a pytree.

    See also :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_structure(tree)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    >>> tree_structure(1)
    PyTreeSpec(*, NoneIsLeaf)
    >>> tree_structure(None)
    PyTreeSpec(*, NoneIsLeaf)

    Args:
        tree (pytree): A pytree to flatten.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A treespec object representing the structure of the pytree.
    """
    return optree.tree_structure(  # type: ignore[return-value]
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )


def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_map_`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': True}

    If multiple inputs are given, the structure of the tree is taken from the first input;
    subsequent inputs need only have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
        is the tuple of values at corresponding nodes in ``rests``.
    """
    return optree.tree_map(
        func,
        tree,
        *rests,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )


def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.

    See also :func:`tree_map`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        The original ``tree`` with the value at each leaf is given by the side-effect of function
        ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
        in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    return optree.tree_map_(
        func,
        tree,
        *rests,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )


# Typing helpers for the ``*_only`` family: tuples of two/three types, and the
# "anything isinstance() accepts" union (which includes ``types.UnionType`` on 3.10+).
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
if sys.version_info >= (3, 10):
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
else:
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

MapOnlyFn = Callable[[T], Callable[[Any], Any]]


# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...
@overload
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...


# Catch-all specialization for callers that pass an arbitrary ``TypeAny``
# (e.g. the untyped implementations further down that call ``map_only``).
@overload
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
) -> MapOnlyFn[FnAny[Any]]:
    """
    Build a decorator that applies a function only to selected values.

    The selector is either a type, a tuple of types, a ``types.UnionType``
    (Python 3.10+), or a predicate callable. The decorated function is called
    on values the selector accepts; every other value is returned unchanged.

    Instead of writing:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    you can write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'.

    Raises:
        TypeError: If the selector is neither a type, a tuple of types, a
            union type, nor a callable.
    """
    selector = __type_or_types_or_pred

    # Decide whether the selector is something ``isinstance`` understands.
    # ``types.UnionType`` only exists on Python 3.10+, so guard the lookup.
    isinstance_compatible = isinstance(selector, (type, tuple))
    if not isinstance_compatible and sys.version_info >= (3, 10):
        isinstance_compatible = isinstance(selector, types.UnionType)

    if isinstance_compatible:

        def accepts(value: Any) -> bool:
            return isinstance(value, selector)  # type: ignore[arg-type]

    elif callable(selector):
        accepts = selector  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

    def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(func)
        def wrapped(x: T) -> Any:
            # Pass non-matching values through untouched.
            return func(x) if accepts(x) else x

        return wrapped

    return wrapper


@overload
def tree_map_only(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...
@overload
def tree_map_only(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


def tree_map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Map ``func`` over ``tree`` via :func:`tree_map`, restricted by :func:`map_only`.

    ``func`` is first wrapped with ``map_only(__type_or_types_or_pred)`` and the wrapped
    function is then passed to :func:`tree_map` together with ``tree`` and ``is_leaf``.
    """
    selective_func = map_only(__type_or_types_or_pred)(func)
    return tree_map(selective_func, tree, is_leaf=is_leaf)


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...
def tree_map_only_(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Map ``func`` over ``tree`` via :func:`tree_map_`, restricted by :func:`map_only`.

    ``func`` is first wrapped with ``map_only(__type_or_types_or_pred)`` and the wrapped
    function is then passed to :func:`tree_map_` together with ``tree`` and ``is_leaf``.
    """
    selective_func = map_only(__type_or_types_or_pred)(func)
    return tree_map_(selective_func, tree, is_leaf=is_leaf)


def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for every leaf yielded by :func:`tree_iter`.

    Evaluation stops at the first leaf for which ``pred`` is falsy.
    """
    return all(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))


def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for at least one leaf yielded by :func:`tree_iter`.

    Evaluation stops at the first leaf for which ``pred`` is truthy.
    """
    return any(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))


@overload
def tree_all_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_all_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for every leaf that is an instance of ``__type_or_types``.

    Leaves of other types are ignored, so the result is ``True`` when no leaf matches.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        # Only matching leaves are tested; the first failure decides the result.
        if isinstance(leaf, __type_or_types) and not pred(leaf):
            return False
    return True


@overload
def tree_any_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...
@overload
def tree_any_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_any_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for some leaf that is an instance of ``__type_or_types``.

    Leaves of other types are ignored, so the result is ``False`` when no leaf matches.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        # Only matching leaves are tested; the first success decides the result.
        if isinstance(leaf, __type_or_types) and pred(leaf):
            return True
    return False


def broadcast_prefix(
    prefix_tree: PyTree,
    full_tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.

    ``prefix_tree`` is a prefix of ``full_tree`` when ``full_tree`` can be obtained by replacing
    the leaves of ``prefix_tree`` with appropriate **subtrees**.

    The returned list has one entry per leaf of ``full_tree``. Each leaf of ``prefix_tree`` is
    repeated once for every leaf of the subtree that sits at the same position in ``full_tree``.

    >>> broadcast_prefix(1, [1, 2, 3])
    [1, 1, 1]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3])
    [1, 2, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
    >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
    [1, 2, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
    [1, 2, 3, 3, 3, 3]

    Args:
        prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
        full_tree (pytree): A pytree that has ``prefix_tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
    """
    broadcasted = optree.broadcast_prefix(
        prefix_tree,
        full_tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
    return broadcasted


# Broadcasts a pytree to the provided TreeSpec and returns the flattened
# values. If this is not possible, then this function returns None.
#
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
# would return [0, 0]. This is useful for part of the vmap implementation:
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
# broadcastable to the tree structure of `inputs` and we use
# _broadcast_to_and_flatten to check this.
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
    """Broadcast ``tree`` to the structure of ``treespec`` and return the flattened leaves.

    A placeholder tree with the structure of ``treespec`` is built (every leaf is ``0``) and
    ``tree`` is broadcast against it with :func:`broadcast_prefix`.

    Returns:
        The broadcasted list of leaves, or :data:`None` if :func:`broadcast_prefix` raises
        :class:`ValueError`.
    """
    assert isinstance(treespec, TreeSpec)
    # Only the structure of ``full_tree`` matters; the leaf values are dummies.
    full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
    try:
        return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
    except ValueError:
        return None


def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
    """Serialize a treespec to a JSON string.

    The treespec is converted to a Python-pytree treespec (by unflattening a dummy tree and
    recomputing its structure with ``torch.utils._pytree.tree_structure``) and then serialized
    with ``torch.utils._pytree.treespec_dumps``.

    Raises:
        TypeError: If ``treespec`` is not an instance of :class:`TreeSpec`.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"treespec_dumps(spec): Expected `spec` to be instance of "
            f"TreeSpec but got item of type {type(treespec)}."
        )
    from ._pytree import (
        tree_structure as _tree_structure,
        treespec_dumps as _treespec_dumps,
    )

    # Round-trip through a dummy tree to obtain the equivalent Python-pytree spec.
    orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
    return _treespec_dumps(orig_treespec, protocol=protocol)


def treespec_loads(serialized: str) -> TreeSpec:
    """Deserialize a treespec from a JSON string.

    The string is loaded with ``torch.utils._pytree.treespec_loads``; the resulting
    Python-pytree spec is converted to a :class:`TreeSpec` of this module by unflattening a
    dummy tree and calling :func:`tree_structure` on it.
    """
    from ._pytree import (
        tree_unflatten as _tree_unflatten,
        treespec_loads as _treespec_loads,
    )

    orig_treespec = _treespec_loads(serialized)
    # Only the structure of the dummy tree matters; the leaf values are placeholders.
    dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
    treespec = tree_structure(dummy_tree)
    return treespec


class _DummyLeaf:
    """Placeholder leaf whose ``repr`` is ``*``; used by :func:`treespec_pprint`."""

    def __repr__(self) -> str:
        return "*"


def treespec_pprint(treespec: TreeSpec) -> str:
    """Return the ``repr`` of a tree shaped like ``treespec`` with every leaf shown as ``*``."""
    dummy_tree = tree_unflatten(
        [_DummyLeaf() for _ in range(treespec.num_leaves)],
        treespec,
    )
    return repr(dummy_tree)


class LeafSpecMeta(type(TreeSpec)):  # type: ignore[misc]
    """Metaclass making ``isinstance(x, LeafSpec)`` true for any leaf :class:`TreeSpec`."""

    def __instancecheck__(self, instance: object) -> bool:
        return isinstance(instance, TreeSpec) and instance.is_leaf()


class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
    """Treespec of a single leaf.

    Calling ``LeafSpec()`` returns the object produced by
    ``optree.treespec_leaf(none_is_leaf=True)``.
    """

    def __new__(cls) -> "LeafSpec":
        return optree.treespec_leaf(none_is_leaf=True)  # type: ignore[return-value]


def tree_flatten_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
    """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.

    Not implemented in this module: calling it always raises :class:`NotImplementedError`.
    The description below is the intended contract.

    Args:
        tree: a pytree to flatten. If it contains a custom type, that type must be
            registered with an appropriate ``flatten_with_keys_fn`` when registered
            with :func:`register_pytree_node`.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A tuple where the first element is a list of (key path, leaf) pairs, and the
        second element is a :class:`TreeSpec` representing the structure of the flattened
        tree.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")


def tree_leaves_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Tuple[KeyPath, Any]]:
    """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.

    Not implemented in this module: calling it always raises :class:`NotImplementedError`.
    The description below is the intended contract.

    Args:
        tree: a pytree. If it contains a custom type, that type must be
            registered with an appropriate ``flatten_with_keys_fn`` when registered
            with :func:`register_pytree_node`.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A list of (key path, leaf) pairs.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")


def tree_map_with_path(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but the provided callable takes an additional key path argument.

    Not implemented in this module: calling it always raises :class:`NotImplementedError`.
    The description below is the intended contract.

    Args:
        func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees. The first positional argument
            to ``func`` is the key path of the leaf in question. The second
            positional argument is the value of the leaf.
        tree: A pytree to be mapped over, with each leaf providing the second positional
            argument to function ``func``.
        rests: A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
        corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
        ``xs`` is the tuple of values at corresponding nodes in ``rests``.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")


def keystr(kp: KeyPath) -> str:
    """Given a key path, return a pretty-printed representation.

    Not implemented in this module: always raises :class:`NotImplementedError`.
    """
    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")


def key_get(obj: Any, kp: KeyPath) -> Any:
    """Given an object and a key path, return the value at the key path.

    Not implemented in this module: always raises :class:`NotImplementedError`.
    """
    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")


# Import-time side effect: flag the C++ pytree as imported on the Python pytree
# module, then replay every (args, kwargs) registration stored in
# ``_pytree._cxx_pytree_pending_imports`` against the C++ registry.
# NOTE(review): presumably these were queued by the Python pytree before this
# module was imported -- confirm against torch.utils._pytree.
_pytree._cxx_pytree_imported = True
for args, kwargs in _pytree._cxx_pytree_pending_imports:
    _private_register_pytree_node(*args, **kwargs)