1# Copyright (c) Meta Platforms, Inc. and affiliates 2from typing import ( 3 Callable, 4 cast, 5 Collection, 6 List, 7 Mapping, 8 MutableMapping, 9 Optional, 10 Tuple, 11 TypeVar, 12 Union, 13) 14 15import torch 16from torch.distributed._shard.sharded_tensor.api import ShardedTensor 17from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE 18from torch.distributed.tensor import DTensor 19 20 21PATH_ITEM = Union[str, int] 22OBJ_PATH = Tuple[PATH_ITEM, ...] 23T = TypeVar("T") 24 25STATE_DICT_ITEM = object 26CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] 27 28__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] 29 30 31def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: 32 return isinstance(value, torch.Tensor) 33 34 35# TODO: update docstring for traverse.py 36def traverse_state_dict( 37 state_dict: STATE_DICT_TYPE, 38 visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], 39 keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, 40) -> None: 41 """ 42 Invoke ``visitor`` for each value recursively in ``state_dict``. 43 Mapping will be traversed and ``visitor`` will be applied to the leaf elements. 44 ``visitor`` will only be applied to elements in a list or a tuple, if the 45 container contains tensors or mappings. 46 """ 47 48 def _is_terminal(value: STATE_DICT_ITEM) -> bool: 49 values: Collection[STATE_DICT_ITEM] 50 if isinstance(value, Mapping): 51 return False 52 elif isinstance(value, list): 53 values = value 54 else: 55 return True 56 57 for entry in values: 58 if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): 59 return False 60 if keep_traversing is not None and keep_traversing(entry): 61 return False 62 return True 63 64 def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: 65 if isinstance(value, Mapping): 66 for k, v in value.items(): 67 _traverse_obj(path + (str(k),), v) 68 elif _is_terminal(value): 69 visitor(path, value) 70 elif isinstance(value, (list, tuple)): 71 for i, v in enumerate(value): 72 _traverse_obj(path + (i,), v) 73 74 for key, value in state_dict.items(): 75 _traverse_obj((str(key),), value) 76 77 78def traverse_state_dict_v_2_3( 79 state_dict: STATE_DICT_TYPE, 80 visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], 81 keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, 82) -> None: 83 """ 84 Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates 85 to false for all elements. 86 By default, all collections with at least one ``torch.Tensor`` element are traversed. 87 Visitor takes a path argument that is a tuple of the keys used to reach it. 88 """ 89 90 # a value is terminal if it has no other containers values inside it 91 def _is_terminal(value: STATE_DICT_ITEM) -> bool: 92 values: Collection[STATE_DICT_ITEM] 93 if isinstance(value, Mapping): 94 values = value.values() 95 elif isinstance(value, list): 96 values = value 97 else: 98 return True 99 100 for entry in values: 101 if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): 102 return False 103 if keep_traversing is not None and keep_traversing(entry): 104 return False 105 return True 106 107 def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: 108 if _is_terminal(value): 109 visitor(path, value) 110 elif isinstance(value, Mapping): 111 for k, v in value.items(): 112 _traverse_obj(path + (str(k),), v) 113 elif isinstance(value, list): 114 for i, v in enumerate(value): 115 _traverse_obj(path + (i,), v) 116 117 for key, value in state_dict.items(): 118 _traverse_obj((str(key),), value) 119 120 121def set_element( 122 root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM 123) -> None: 124 """Set ``value`` in ``root_dict`` along the ``path`` object path.""" 125 cur_container = cast(CONTAINER_TYPE, root_dict) 126 127 def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None: 128 while len(lst) <= idx: 129 lst.append(None) 130 131 for i in range(1, len(path)): 132 prev_key = path[i - 1] 133 key = path[i] 134 def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else []) 135 136 if isinstance(cur_container, Mapping): 137 cur_container = cast( 138 CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) 139 ) 140 else: 141 extend_list(cur_container, prev_key) 142 if cur_container[prev_key] is None: 143 cur_container[prev_key] = def_val 144 cur_container = cur_container[prev_key] 145 146 key = path[-1] 147 if type(key) == int: 148 extend_list(cast(List[STATE_DICT_ITEM], cur_container), key) 149 150 cur_container[key] = value 151 152 153def get_element( 154 root_dict: STATE_DICT_TYPE, 155 path: OBJ_PATH, 156 default_value: Optional[T] = None, 157) -> Optional[T]: 158 """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found.""" 159 cur_value = cast(CONTAINER_TYPE, root_dict) 160 for part in path: 161 if type(part) is int: 162 if not isinstance(cur_value, list) or len(cur_value) < part: 163 return default_value 164 elif not isinstance(cur_value, Mapping) or part not in cur_value: 165 return default_value 166 167 cur_value = cast(CONTAINER_TYPE, cur_value[part]) 168 return cast(Optional[T], cur_value) 169 170 171def _print_nested( 172 value: STATE_DICT_ITEM, 173 prefix: str = "", 174 print_fun: Callable[[str], None] = print, 175) -> None: 176 if type(value) is ShardedTensor: 177 print_fun(f"{prefix} ShardedTensor size: {value.size()}") 178 for shard in value.local_shards(): 179 _print_nested( 180 shard.tensor, 181 f"{shard.metadata.shard_offsets} ", 182 print_fun=print_fun, 183 ) 184 elif type(value) is (DTensor): 185 print_fun(f"{prefix} DistributedTensor size: {value.size()}") 186 # TODO: add local offset for _local_tensor in print_nested. 187 _print_nested( 188 value._local_tensor, 189 print_fun=print_fun, 190 ) 191 elif isinstance(value, torch.Tensor): 192 print_fun(f"{prefix} Tensor size: {value.size()}") 193 else: 194 print_fun(f"{prefix} Type: {type(value)}") 195 196 197def print_tensor( 198 path: OBJ_PATH, 199 value: STATE_DICT_ITEM, 200 print_fun: Callable[[str], None] = print, 201) -> None: 202 """ 203 Use this callback with traverse_state_dict to print its content. 204 205 By default the content is printed using the builtin ``print`` but this can 206 be change by passing a different ``print_fun` callable. 207 """ 208 _print_nested(value, prefix=str(path), print_fun=print_fun) 209