xref: /aosp_15_r20/external/pytorch/torch/distributed/checkpoint/_traverse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Meta Platforms, Inc. and affiliates
2from typing import (
3    Callable,
4    cast,
5    Collection,
6    List,
7    Mapping,
8    MutableMapping,
9    Optional,
10    Tuple,
11    TypeVar,
12    Union,
13)
14
15import torch
16from torch.distributed._shard.sharded_tensor.api import ShardedTensor
17from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
18from torch.distributed.tensor import DTensor
19
20
# One step along an object path: a string key into a mapping or an integer
# index into a list (the traversal helpers below build paths from ``str(k)``
# mapping keys and ``enumerate`` list indices).
PATH_ITEM = Union[str, int]
# Full path from a top-level state_dict key down to a nested value.
OBJ_PATH = Tuple[PATH_ITEM, ...]
# Type variable for the value (or default) returned by ``get_element``.
T = TypeVar("T")

# Any value that may appear inside a state_dict.
STATE_DICT_ITEM = object
# Container type used with ``cast`` while walking or building nested
# state_dict structures in ``set_element`` / ``get_element``.
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]

# Public API of this module.
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
29
30
def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
    """
    Return True when ``value`` is a ``torch.Tensor`` (subclasses included).

    This is the default ``keep_traversing`` predicate of the traversal
    functions: a list holding at least one tensor is descended into instead
    of being handed to the visitor as a single leaf.
    """
    return torch.is_tensor(value)
33
34
def traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` for each leaf value found recursively in ``state_dict``.

    Traversal rules:

    * Mappings are always descended into; each key is converted with ``str``
      and appended to the path. ``visitor`` is never called on a mapping
      itself, so an empty mapping produces no calls.
    * A list is descended into (appending the ``int`` index to the path) only
      if it contains a mapping, an element for which ``keep_traversing``
      returns True, or a nested list that itself must be descended into.
      Otherwise the whole list is passed to ``visitor`` as a single leaf.
    * Every other value, tuples included, is a leaf and is passed to
      ``visitor`` unchanged.

    Args:
        state_dict: The top-level state dict to walk. Its entries are always
            visited individually, keyed by ``str(key)``.
        visitor: Called as ``visitor(path, value)`` for each leaf, where
            ``path`` is the tuple of keys/indices leading to ``value``.
        keep_traversing: Predicate applied to list elements; a True result
            for any element forces the enclosing list to be traversed. The
            default treats lists containing a ``torch.Tensor`` as
            traversable. ``None`` disables the predicate.
    """

    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
        # Mappings are never leaves; anything that is not a list (tuples
        # included) always is.
        if isinstance(value, Mapping):
            return False
        if not isinstance(value, list):
            return True

        for entry in value:
            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
                return False
            if keep_traversing is not None and keep_traversing(entry):
                return False
        return True

    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif _is_terminal(value):
            visitor(path, value)
        else:
            # Only non-terminal lists reach this branch: _is_terminal returns
            # True for every non-mapping, non-list value (including tuples).
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
76
77
def traverse_state_dict_v_2_3(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
    keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
) -> None:
    """
    Invoke ``visitor`` on each terminal value found recursively in ``state_dict``.

    Traversal is short-circuited when it finds a collection (mapping or list)
    for which ``keep_traversing`` evaluates to False for every element,
    including the elements of nested mappings and lists. Such a collection is
    passed to ``visitor`` as a single value instead of being descended into.
    By default, only collections with at least one ``torch.Tensor`` element
    (at any nesting depth) are traversed.

    Every value that is neither a mapping nor a list (tuples included) is
    terminal. Mapping keys are converted with ``str`` before being appended
    to the path; list positions are appended as ``int`` indices.

    Args:
        state_dict: The top-level state dict to walk. Its entries are always
            visited individually, keyed by ``str(key)``.
        visitor: Called as ``visitor(path, value)`` for each terminal value,
            where ``path`` is the tuple of keys/indices used to reach it.
        keep_traversing: Predicate applied to collection elements; a True
            result for any element forces the collection to be traversed.
            ``None`` disables the predicate.
    """

    # A value is terminal if it is not a mapping/list, or if it is one that
    # contains nothing (at any nesting depth) requiring further traversal.
    def _is_terminal(value: STATE_DICT_ITEM) -> bool:
        values: Collection[STATE_DICT_ITEM]
        if isinstance(value, Mapping):
            values = value.values()
        elif isinstance(value, list):
            values = value
        else:
            return True

        for entry in values:
            if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
                return False
            if keep_traversing is not None and keep_traversing(entry):
                return False
        return True

    def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        if _is_terminal(value):
            visitor(path, value)
        elif isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif isinstance(value, list):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
119
120
def set_element(
    root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
) -> None:
    """
    Store ``value`` in ``root_dict`` at the location named by ``path``.

    Intermediate containers that do not exist yet are created while walking
    down: a dict when the next path element is a ``str`` and a list when it
    is anything else. Lists are padded with ``None`` so the requested index
    exists, and a ``None`` slot in a list is replaced by a new container when
    the path continues through it.
    """
    container = cast(CONTAINER_TYPE, root_dict)

    def _pad_to(seq: List[STATE_DICT_ITEM], index: int) -> None:
        # Append None placeholders until ``seq[index]`` is addressable.
        while len(seq) <= index:
            seq.append(None)

    # Walk consecutive (parent, child) pairs of the path; the child element
    # decides what kind of container must live under the parent.
    for parent_key, child_key in zip(path, path[1:]):
        fresh = cast(STATE_DICT_ITEM, {} if type(child_key) is str else [])
        if isinstance(container, Mapping):
            container = cast(
                CONTAINER_TYPE, container.setdefault(parent_key, fresh)
            )
            continue
        _pad_to(container, parent_key)
        if container[parent_key] is None:
            container[parent_key] = fresh
        container = container[parent_key]

    last_key = path[-1]
    if type(last_key) is int:
        _pad_to(cast(List[STATE_DICT_ITEM], container), last_key)
    container[last_key] = value
151
152
def get_element(
    root_dict: STATE_DICT_TYPE,
    path: OBJ_PATH,
    default_value: Optional[T] = None,
) -> Optional[T]:
    """
    Retrieve the value at ``path`` from ``root_dict``, returning ``default_value`` if not found.

    Each ``int`` path element indexes a list; every other element is looked
    up as a mapping key. The lookup returns ``default_value`` as soon as an
    element does not apply: an ``int`` applied to a non-list or outside the
    list's valid index range, or a key applied to a non-mapping or missing
    from the mapping. An empty ``path`` returns ``root_dict`` itself.
    """
    cur_value = cast(CONTAINER_TYPE, root_dict)
    for part in path:
        if type(part) is int:
            # Accept exactly the indices list indexing accepts (negative ones
            # included). Out-of-range indices, including part == len(cur_value),
            # must yield default_value rather than raise IndexError.
            if not isinstance(cur_value, list) or not (
                -len(cur_value) <= part < len(cur_value)
            ):
                return default_value
        elif not isinstance(cur_value, Mapping) or part not in cur_value:
            return default_value

        cur_value = cast(CONTAINER_TYPE, cur_value[part])
    return cast(Optional[T], cur_value)
169
170
171def _print_nested(
172    value: STATE_DICT_ITEM,
173    prefix: str = "",
174    print_fun: Callable[[str], None] = print,
175) -> None:
176    if type(value) is ShardedTensor:
177        print_fun(f"{prefix} ShardedTensor size: {value.size()}")
178        for shard in value.local_shards():
179            _print_nested(
180                shard.tensor,
181                f"{shard.metadata.shard_offsets} ",
182                print_fun=print_fun,
183            )
184    elif type(value) is (DTensor):
185        print_fun(f"{prefix} DistributedTensor size: {value.size()}")
186        # TODO: add local offset for _local_tensor in print_nested.
187        _print_nested(
188            value._local_tensor,
189            print_fun=print_fun,
190        )
191    elif isinstance(value, torch.Tensor):
192        print_fun(f"{prefix} Tensor size: {value.size()}")
193    else:
194        print_fun(f"{prefix} Type: {type(value)}")
195
196
def print_tensor(
    path: OBJ_PATH,
    value: STATE_DICT_ITEM,
    print_fun: Callable[[str], None] = print,
) -> None:
    """
    Use this callback with ``traverse_state_dict`` to print its content.

    The description is produced by ``_print_nested`` with ``str(path)`` as
    the prefix of the first line. Plain tensors report their size.
    ``ShardedTensor`` and ``DTensor`` values also report their local shards
    or local tensor, and any other value reports its type.

    By default the content is printed using the builtin ``print``, but this
    can be changed by passing a different ``print_fun`` callable.
    """
    _print_nested(value, prefix=str(path), print_fun=print_fun)
209