1# mypy: allow-untyped-defs 2import copy 3import functools 4import logging 5import warnings 6from contextlib import ExitStack 7from dataclasses import dataclass, field 8from typing import ( 9 Any, 10 cast, 11 Dict, 12 Iterable, 13 Iterator, 14 List, 15 NamedTuple, 16 no_type_check, 17 Optional, 18 Sequence, 19 Set, 20 Tuple, 21 TYPE_CHECKING, 22 Union, 23) 24 25import torch 26import torch.distributed as dist 27import torch.distributed.fsdp._traversal_utils as traversal_utils 28import torch.nn as nn 29from torch.distributed._state_dict_utils import _gather_state_dict 30from torch.distributed.distributed_c10d import _get_pg_default_device 31from torch.distributed.fsdp._common_utils import ( 32 _apply_to_modules, 33 _FSDPState, 34 _get_module_fsdp_state_if_fully_sharded_module, 35 _get_param_to_fqns, 36 _module_handle, 37 _named_parameters_with_duplicates, 38 clean_tensor_name, 39) 40from torch.distributed.fsdp._debug_utils import SimpleProfiler 41from torch.distributed.fsdp._flat_param import FlatParameter, FlatParamHandle 42from torch.distributed.fsdp._fsdp_extensions import ( 43 _ext_chunk_dtensor, 44 _ext_chunk_tensor, 45) 46from torch.distributed.fsdp._runtime_utils import ( 47 _lazy_init, 48 _reset_flat_param_grad_info_if_needed, 49) 50from torch.distributed.fsdp.api import ( 51 ShardingStrategy, 52 StateDictSettings, 53 StateDictType, 54) 55from torch.distributed.tensor import DTensor, Replicate 56from torch.utils._pytree import tree_map_only 57 58 59if TYPE_CHECKING: 60 from torch.distributed._shard.sharded_tensor import ShardedTensor 61 62 63logger = logging.getLogger(__name__) 64 65 66@dataclass 67class FSDPParamInfo: 68 state: _FSDPState 69 handle: FlatParamHandle 70 param_indices: Dict[str, int] 71 param_requires_grad: List[bool] 72 73 74def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]: 75 keys = sorted(dictionary.keys()) 76 for k in keys: 77 yield k, dictionary[k] 78 79 80@dataclass 81class _ConsolidatedOptimState: 82 """ 83 This holds the consolidated optimizer state on the target rank. Positive- 84 dimension tensor state is communicated across ranks, while zero-dimension 85 tensor state and non-tensor state is taken directly from the target rank. 86 87 PyTorch version 1.12 moved to using zero-dimension tensors for scalar 88 values, but user implemented optimizers may still use float (i.e. a 89 non-tensor). Thus, we support both and handle them identically. 90 91 Attributes: 92 tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension 93 tensor state name to the unsharded flat tensor representing the 94 state. 95 zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero- 96 dimension tensor state name to its value. 97 non_tensor_state (Dict[str, Any]): Mapping from non-tensor state 98 name to its value. 99 """ 100 101 tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict) 102 zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict) 103 non_tensor_state: Dict[str, Any] = field(default_factory=dict) 104 105 106class _PosDimTensorInfo(NamedTuple): 107 """ 108 Meatadata for positive-dimension tensors used internally for 109 :meth:`scatter_full_optim_state_dict`. 110 111 Attributes: 112 shape (torch.Size): Sharded tensor shape (which is equal to the 113 unsharded tensor shape if the tensor is optimizer state for a 114 non-FSDP parameter and is hence not sharded). 115 dtype (torch.dtype): Data type of the tensor. 116 """ 117 118 shape: torch.Size 119 dtype: torch.dtype 120 121 122class _OptimStateKey(NamedTuple): 123 """ 124 This represents an optimizer state key that may be used commonly across 125 ranks. It is based on the unflattened parameter names rather than parameter 126 IDs to make it independent of each rank's own optimizer construction. 127 """ 128 129 unflat_param_names: Tuple[str, ...] 130 is_fsdp_managed: bool 131 132 133def _unflatten_optim_state( 134 fsdp_param_info: FSDPParamInfo, 135 flat_param_state: Dict[str, Any], 136 to_save: bool, 137 shard_state: bool, 138 cpu_offload: bool, 139) -> List[Dict[str, Any]]: 140 """ 141 Unflattens the optimizer state, consisting of the "state" part and the 142 "param_groups" part. Unflattening the "state" part involves consolidating 143 the state on the target rank and remapping from flattened to unflattened 144 parameter IDs, and the "param_groups" part only involves remapping from 145 flattened to unflattened parameter IDs. 146 147 Args: 148 fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a 149 mapping from FQN to original parameter index. 150 flat_param_state (Dict[str, Any]): Entry for the flat parameter in the 151 "state" part of the optimizer state dict. 152 to_save (bool): Whether to save the state on this rank. 153 154 Returns: 155 List[Dict[str, Any]]: A :class:`list` holding the entries in the 156 "state" part of the optimizer state dict corresponding to the 157 unflattened parameters comprising the flat parameter if on the target 158 rank or an empty :class:`list` otherwise. The final optimizer state 159 dict will need to map these entries using the proper unflattened 160 parameter IDs. 161 """ 162 assert ( 163 not shard_state or to_save 164 ), "If ``shard_state`` is True, ``to_save`` has to be True." 165 consolidated_state = _communicate_optim_state( 166 fsdp_param_info, 167 flat_param_state, 168 ) 169 if to_save: 170 unflat_param_state = _unflatten_communicated_optim_state( 171 fsdp_param_info, 172 consolidated_state, 173 shard_state, 174 ) 175 for optim_state in unflat_param_state: 176 # We can't use .items() below cuz we'd run into a concurrent modification error 177 if cpu_offload: 178 for key in list(optim_state.keys()): 179 state = optim_state[key] 180 if not isinstance(state, torch.Tensor): 181 continue 182 optim_state[key] = state.cpu() 183 return unflat_param_state 184 else: 185 return [] 186 187 188def _is_zero_dim_tensor(x: Any) -> bool: 189 return torch.is_tensor(x) and x.dim() == 0 190 191 192def _communicate_optim_state( 193 fsdp_param_info: FSDPParamInfo, 194 flat_param_state: Dict[str, Any], 195) -> _ConsolidatedOptimState: 196 """ 197 Communicates the optimizer state for a flat parameter across ranks. All 198 ranks will hold the entire non-sharded optimizer state on GPU. 199 200 If ``N`` is the number of tensor optimizer states in the optimizer state 201 dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1`` 202 otherwise (where the plus 1 comes from all-gathering the padding per rank). 203 204 Args: 205 fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a 206 mapping from FQN to original parameter index. 207 flat_param_state (Dict[str, Any]): The entry in the "state" part of the 208 optimizer state dict corresponding to the flat parameter. 209 210 Returns: 211 ConsolidatedOptimState: Consolidated optimizer state for the target 212 flat parameter. 213 """ 214 fsdp_state = fsdp_param_info.state 215 flat_param = fsdp_param_info.handle.flat_param 216 state = _ConsolidatedOptimState() 217 tensor_state, zero_dim_tensor_state, non_tensor_state = ( 218 state.tensor_state, 219 state.zero_dim_tensor_state, 220 state.non_tensor_state, 221 ) 222 223 for state_name, value in sorted_items(flat_param_state): 224 # Positive-dimension tensor state: communicate across ranks 225 if torch.is_tensor(value) and value.dim() > 0: 226 # If the parameter is not sharded, then neither is the 227 # positive-dimension tensor state, so no need to communicate it -- 228 # we take the target rank's value 229 if ( 230 fsdp_state.world_size == 1 231 or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD 232 ): 233 tensor_state[state_name] = value 234 continue 235 assert ( 236 fsdp_state.compute_device is not None 237 ), "compute_device has not been initialized" 238 if value.device.type != fsdp_state.compute_device.type: 239 value = value.to(fsdp_state.compute_device) 240 # Assume that positive-dimension tensor optimizer state 241 # has the same shape as the sharded flat parameter 242 buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined] 243 tensor_buffer = value.new_zeros(*buffer_size) 244 dist.all_gather_into_tensor( 245 tensor_buffer, value, group=fsdp_state.process_group 246 ) 247 fsdp_state._device_handle.synchronize() 248 unpadded_numel = cast( 249 nn.Parameter, flat_param._unpadded_unsharded_size 250 ).numel() 251 tensor_state[state_name] = tensor_buffer[:unpadded_numel] 252 # Zero-dimension tensor state and non-tensor state: take this rank's 253 # value directly 254 else: 255 if _is_zero_dim_tensor(value): 256 zero_dim_tensor_state[state_name] = value.detach().clone() 257 else: 258 non_tensor_state[state_name] = value 259 return state 260 261 262def _unflatten_communicated_optim_state( 263 fsdp_param_info: FSDPParamInfo, 264 state: _ConsolidatedOptimState, 265 shard_state: bool, 266) -> List[Dict[str, Any]]: 267 """ 268 Unflattens the communicated optimizer state (given by ``tensor_state``, 269 ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat 270 parameter. This should only be called on the target rank. 271 272 Args: 273 fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a 274 mapping from FQN to original parameter index. 275 state (_ConsolidatedOptimState): Consolidated optimizer state. 276 277 Returns: 278 List[Dict[str, Any]]: A :class:`list` holding the entries in the 279 "state" part of the optimizer state dict corresponding to the 280 unflattened parameters comprising the flat parameter. The final 281 optimizer state dict will need to map these entries using the proper 282 unflattened parameter IDs. 283 """ 284 fsdp_state = fsdp_param_info.state 285 handle = fsdp_param_info.handle 286 flat_param = handle.flat_param 287 unflat_param_state: List[Dict[str, Any]] = [] 288 flat_param_views: Dict[str, Iterator] = {} 289 num_unflat_params = flat_param._num_params 290 tensor_state, zero_dim_tensor_state, non_tensor_state = ( 291 state.tensor_state, 292 state.zero_dim_tensor_state, 293 state.non_tensor_state, 294 ) 295 296 for _ in range(num_unflat_params): 297 unflat_state_param = {} 298 # Add positive-dimension tensor state: unflatten with views 299 for state_name, flat_tensor in sorted_items(tensor_state): 300 views_generated = state_name in flat_param_views 301 if not views_generated: 302 views = handle._get_unflat_views(flat_tensor) 303 flat_param_views[state_name] = views 304 else: 305 views = flat_param_views[state_name] 306 optim_state: Union[torch.Tensor, ShardedTensor, DTensor] = next(views) 307 if shard_state: 308 osd_config = fsdp_state._optim_state_dict_config 309 if getattr(osd_config, "_use_dtensor", False): 310 assert fsdp_state._device_mesh is not None 311 optim_state = _ext_chunk_dtensor( 312 optim_state, 313 fsdp_state.rank, 314 fsdp_state._device_mesh, 315 fsdp_state._fsdp_extension, 316 ) 317 else: 318 assert fsdp_state.process_group is not None 319 optim_state = _ext_chunk_tensor( 320 optim_state, 321 fsdp_state.rank, 322 fsdp_state.world_size, 323 fsdp_state._device_handle.device_count(), 324 fsdp_state.process_group, 325 fsdp_state._fsdp_extension, 326 ) 327 unflat_state_param[state_name] = optim_state 328 329 # Add zero-dimension tensor state: take the target rank's value 330 for state_name, zero_dim_tensor in sorted_items(zero_dim_tensor_state): 331 unflat_state_param[state_name] = zero_dim_tensor 332 # Add non-tensor state: take the target rank's value 333 for state_name, non_tensor in sorted_items(non_tensor_state): 334 unflat_state_param[state_name] = non_tensor 335 unflat_param_state.append(unflat_state_param) 336 return unflat_param_state 337 338 339def _broadcast_processed_state( 340 fsdp_state: _FSDPState, 341 optim_state: Dict[str, Any], 342 group: Optional[dist.ProcessGroup], 343) -> Dict[str, Any]: 344 objects: List[Any] = [None] 345 if dist.get_rank(group) == 0: 346 objects[0] = tree_map_only( 347 torch.Tensor, 348 lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), # type: ignore[union-attr] 349 optim_state, 350 ) 351 dist.broadcast_object_list(objects, src=0, group=group) 352 if dist.get_rank(group) == 0: 353 return optim_state 354 else: 355 return objects[0] 356 357 358def _broadcast_state( 359 fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup] 360) -> Any: 361 if dist.get_rank(group) == 0: 362 if not isinstance(state, torch.Tensor) or state.dim() == 0: 363 return state 364 tensor = state.to(fsdp_state.compute_device) 365 else: 366 if isinstance(state, torch.Tensor): 367 assert state.dim() == 0, ( 368 "For non-zero ranks, a tensor state should have zero dimension, " 369 "but got the state with shape {state.shape()}." 370 ) 371 return state 372 elif not isinstance(state, _PosDimTensorInfo): 373 return state 374 tensor = torch.zeros( 375 state.shape, dtype=state.dtype, device=fsdp_state.compute_device 376 ) 377 dist.broadcast(tensor, src=0, group=group) 378 return tensor 379 380 381def _shard_orig_param_state( 382 fsdp_param_info: FSDPParamInfo, 383 fqn: str, 384 optim_state: Dict[str, Any], 385) -> Dict[str, Any]: 386 """ 387 Shard the optimizer state for the original parameter with the name ``fqn``. 388 This API should only be used when ``use_orig_params`` is True. 389 """ 390 if not optim_state: 391 return {} 392 fsdp_state = fsdp_param_info.state 393 flat_param = fsdp_param_info.handle.flat_param 394 param_idx = fsdp_param_info.param_indices[fqn] 395 shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined] 396 optim_state = _gather_state_dict( 397 optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device 398 ) 399 if not shard_param_info.in_shard: 400 return {} 401 # Flatten and shard the state. 402 new_optim_state: Dict[str, Any] = {} 403 intra_param_start_idx = shard_param_info.intra_param_start_idx 404 intra_param_end_idx = shard_param_info.intra_param_end_idx 405 for state_name, value in optim_state.items(): 406 if ( 407 torch.is_tensor(value) 408 and value.dim() > 0 409 and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD 410 ): 411 value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone() # type: ignore[operator] 412 new_optim_state[state_name] = value 413 return new_optim_state 414 415 416def _flatten_optim_state_dict( 417 optim_state_dict: Dict[str, Any], 418 model: nn.Module, 419 use_orig_params: bool = False, 420 optim: Optional[torch.optim.Optimizer] = None, 421 rank0_only: bool = False, 422 group: Optional[dist.ProcessGroup] = None, 423) -> Dict[str, Any]: 424 """ 425 Flattens the full optimizer state dict, still keying by unflattened parameter 426 names. 427 428 If ``use_orig_params`` is True, each rank will have all FSDP-managed 429 parameters but some of these parameters may be empty due to the sharding. 430 For a regular optim.Optimizer, states for those empty parameters will 431 not be initialized. So, when aggregating the FQNs across ranks, no assert 432 will be raised on a rank even if it does not have all the states -- it is 433 valid and FSDP know how to aggregate them. However, FSDP has to ignore 434 handling those parameters that are not managed by FSDP and do not exist on 435 the local rank -- it is managed by other parallelism and FSDP does not 436 know ho to handle/aggregate them. 437 438 Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to 439 flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require 440 all the states even if the corresponding parameters are empty. To this end, 441 ``optim`` will be used to to get the initial state of the empty parameters. 442 ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or 443 NamedOptimizer. 444 445 Returns: 446 Dict[str, Any]: The flattened optimizer state dict. 447 """ 448 SimpleProfiler.reset() 449 450 unflat_osd = optim_state_dict 451 if "state" not in unflat_osd and not rank0_only: 452 raise ValueError( 453 '`optim_state_dict` must have the keys "state"' 454 "to be a valid optimizer state dict" 455 ) 456 param_to_fqns = _get_param_to_fqns(model) 457 fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) 458 fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state 459 460 # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. 461 if rank0_only: 462 unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) 463 464 # Construct the "state" part 465 flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {} 466 unflat_osd_state = unflat_osd["state"] 467 all_state_keys = set(unflat_osd_state.keys()) 468 469 for param, fqns in param_to_fqns.items(): 470 fqn = fqns[0] 471 if fqn not in unflat_osd_state: 472 continue 473 all_state_keys.difference_update(fqns) 474 475 if rank0_only: 476 for fqn in fqns: 477 if not unflat_osd_state[fqn]: 478 continue 479 for state_name in unflat_osd_state[fqn].keys(): 480 unflat_osd_state[fqn][state_name] = _broadcast_state( 481 fsdp_state, unflat_osd_state[fqn][state_name], group=group 482 ) 483 fqn = fqns[0] 484 if fqn in fqn_to_fsdp_param_info: 485 fsdp_param_info = fqn_to_fsdp_param_info[fqn] 486 if use_orig_params: 487 with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): 488 flat_state = _shard_orig_param_state( 489 fsdp_param_info, 490 fqn, 491 unflat_osd_state[fqn], 492 ) 493 else: 494 flat_state = _flatten_optim_state( 495 fsdp_param_info, 496 unflat_osd_state, 497 fqns, 498 ) 499 key = _OptimStateKey(tuple(fqns), True) 500 # Only include non-empty states since as expected by 501 # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer 502 # or NamedOptimizer. 503 if flat_state: 504 flat_osd_state[key] = flat_state 505 elif use_orig_params: 506 assert ( 507 len(fqns) == 1 508 ), f"use_orig_params is True but there are multiple FQNs, {fqns}." 509 if optim is not None: # NamedOptimizer or KeyedOptimizer case. 510 state = optim.state.get(param, None) # type: ignore[call-overload] 511 if state is not None: 512 flat_osd_state[key] = copy.deepcopy(state) 513 else: 514 warnings.warn( 515 f"optim_state[{key}] is not on rank{fsdp_state.rank}." 516 ) 517 518 else: 519 raise RuntimeError( 520 f"The state of {key} is empty. This should happen when " 521 "use_orig_params=True." 522 ) 523 else: # do not flatten non-FSDP parameters' states 524 assert len(fqns) == 1 525 key = _OptimStateKey(tuple(fqns), False) 526 flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) 527 528 if rank0_only: 529 for fqn in fqns: 530 if not unflat_osd_state[fqn]: 531 continue 532 for state_name, param_state in list(unflat_osd_state[fqn].items()): 533 if fsdp_state.rank > 0: 534 # Deference the tensor so that PyTorch can collect the memory. 535 del unflat_osd_state[fqn][state_name] 536 else: 537 # Move the tensor in the original osd back to CPU to make the 538 # original osd unaffected. 539 unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][ 540 state_name 541 ].cpu() 542 543 # Handle user-defined state, states that are not associated with parameters. 544 for key in all_state_keys: 545 user_state = unflat_osd_state[key] 546 if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: 547 user_state = _broadcast_state(fsdp_state, user_state, group=group) 548 flat_osd_state[key] = copy.copy(user_state) 549 550 SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") 551 # Construct the "param_groups" part -- copy as is since it will be 552 # rekeyed later according to the target rank's optimizer 553 # Only copy param_groups if it exists in unflat_osd 554 if "param_groups" in unflat_osd: 555 flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) 556 return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} 557 else: 558 return {"state": flat_osd_state} 559 560 561def _flatten_optim_state( 562 fsdp_param_info: FSDPParamInfo, 563 unflat_osd_state: Dict[str, Dict[str, Any]], 564 unflat_param_names: List[str], 565) -> Dict[str, Any]: 566 """ 567 Flattens the optimizer state in ``full_optim_state_dict`` for a single 568 flat parameter in ``fsdp_param_info`` corresponding to the unflattened 569 parameter names in ``unflat_param_names``. 570 571 Args: 572 fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a 573 mapping from FQN to original parameter index. 574 unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the 575 optimizer state dict corresponding to the unflattened parameters. 576 unflat_param_names (List[str]): A :class:`list` of unflattened 577 parameter names corresponding to the flat parameter ``flat_param``. 578 579 Returns: 580 Dict[str, Any]: A :class:`dict` mapping state names to their values for 581 a particular flat parameter. The sharded optimizer state dict's "state" 582 part will map a key to this returned value. 583 """ 584 fsdp_state = fsdp_param_info.state 585 handle = fsdp_param_info.handle 586 flat_param = handle.flat_param 587 num_unflat_params = len(unflat_param_names) 588 assert num_unflat_params > 0, ( 589 "Expects at least one unflattened parameter corresponding to the " 590 "flat parameter" 591 ) 592 unflat_param_shapes = flat_param._shapes 593 num_unflat_param_shapes = len(unflat_param_shapes) 594 assert ( 595 num_unflat_params == num_unflat_param_shapes 596 ), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" 597 598 # Check if these unflattened parameters have any optimizer state 599 has_state = [ 600 bool(unflat_param_name in unflat_osd_state) 601 for unflat_param_name in unflat_param_names 602 ] 603 # If none of the unflattened parameters comprising this flat parameter have 604 # any state, then we do not want an entry in the optimizer state dict 605 if not any(has_state): 606 return {} # no need to flatten any state 607 # There may still be some unflattened parameters with state and some 608 # without 609 unflat_param_states = [ 610 _gather_state_dict( 611 unflat_osd_state[unflat_param_name], 612 pg=fsdp_state.process_group, 613 device=fsdp_state.compute_device, 614 ) 615 if unflat_param_name in unflat_osd_state 616 else None 617 for unflat_param_name in unflat_param_names 618 ] 619 # Check that the unflattened parameters have the same state names 620 state_names = None 621 for unflat_param_state in unflat_param_states: 622 if unflat_param_state is None: 623 continue 624 if state_names is None: 625 state_names = set(unflat_param_state.keys()) 626 else: 627 if state_names != set(unflat_param_state.keys()): 628 raise ValueError( 629 "Differing optimizer state names for the unflattened " 630 f"parameters: {unflat_param_names}" 631 ) 632 assert state_names is not None 633 634 # Flatten the state 635 flat_state: Dict[str, Any] = {} 636 for state_name in state_names: 637 state_values = [ 638 unflat_param_state[state_name] if unflat_param_state is not None else None 639 for unflat_param_state in unflat_param_states 640 ] 641 non_none_state_values = [v for v in state_values if v is not None] 642 # If all ranks have None, this is a None value 643 if not non_none_state_values: 644 flat_state[state_name] = None 645 continue 646 are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True 647 for v in non_none_state_values: 648 are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0 649 are_zero_dim_tensors &= _is_zero_dim_tensor(v) 650 are_non_tensors &= not torch.is_tensor(v) 651 types = {type(v) for v in non_none_state_values} 652 if len(types) != 1 or not ( 653 are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors 654 ): 655 raise ValueError( 656 f"Differing optimizer state types for state {state_name}, " 657 f"values {non_none_state_values}, and unflattened parameter " 658 f"names {unflat_param_names}" 659 ) 660 if are_pos_dim_tensors: 661 flat_tensor = _flatten_tensor_optim_state( 662 state_name, 663 state_values, 664 unflat_param_names, 665 unflat_param_shapes, 666 handle, 667 ) 668 # Shard the flattened tensor immediately to minimize max memory 669 # usage 670 if ( 671 fsdp_state.world_size != 1 672 and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD 673 ): 674 sharded_flat_tensor, _ = FlatParamHandle._get_shard( 675 flat_tensor, 676 fsdp_state.rank, 677 fsdp_state.world_size, 678 ) 679 else: 680 sharded_flat_tensor = flat_tensor 681 flat_state[state_name] = sharded_flat_tensor 682 elif are_zero_dim_tensors: 683 flat_state[state_name] = _flatten_zero_dim_tensor_optim_state( 684 state_name, 685 state_values, 686 unflat_param_names, 687 ) 688 else: 689 assert are_non_tensors 690 flat_state[state_name] = _flatten_non_tensor_optim_state( 691 state_name, 692 state_values, 693 unflat_param_names, 694 ) 695 696 return flat_state 697 698 699def _flatten_tensor_optim_state( 700 state_name: str, 701 pos_dim_tensors: List[torch.Tensor], 702 unflat_param_names: List[str], 703 unflat_param_shapes: Sequence[torch.Size], 704 handle: FlatParamHandle, 705) -> torch.Tensor: 706 """ 707 Flattens the positive-dimension tensor optimizer state given by the values 708 ``tensors`` for the state ``state_name`` for a single flat parameter 709 from ``handle`` corresponding to the unflattened parameter names 710 ``unflat_param_names`` and unflatted parameter shapes 711 ``unflat_param_shapes``. This flattens each unflattened parameter's tensor 712 state into one tensor. 713 714 NOTE: We use zero tensors for any unflattened parameters without state 715 since some value is required to fill those entries. This assumes that the 716 zero tensor is mathematically equivalent to having no state, which is true 717 for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all 718 optimizers. 719 720 Args: 721 state_name (str): Optimizer state name. 722 pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor 723 optimizer state values for the unflattened parameters corresponding 724 to the single flat parameter. 725 unflat_param_names (List[str]): A :class:`list` of unflattened 726 parameter names corresponding to the single flat parameter. 727 unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes 728 corresponding to the single flat parameter. 729 handle (FlatParamHandle): The flat parameter's handle. 730 731 Returns: 732 torch.Tensor: A flat tensor containing the optimizer state 733 corresponding to ``state_name`` constructed by concatenating the 734 unflattened parameter tensor states in ``pos_dim_tensors`` (using zero 735 tensors for any unflattened parameters without the state). 736 """ 737 flat_param = handle.flat_param 738 non_none_tensors = [t for t in pos_dim_tensors if t is not None] 739 # Check that all are tensors with the same dtype 740 dtypes = {t.dtype for t in non_none_tensors} 741 if len(dtypes) != 1: 742 raise ValueError( 743 "All unflattened parameters comprising a single flat " 744 "parameter must have positive-dimension tensor state with the " 745 f"same dtype but got dtypes {dtypes} for state {state_name} and " 746 f"unflattened parameter names {unflat_param_names}" 747 ) 748 dtype = next(iter(dtypes)) 749 # Check that each tensor state matches its parameter's shape 750 for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes): 751 if tensor is None and len(shape) == 0: 752 raise ValueError("Flattening a zero-dimension parameter is not supported") 753 elif tensor is not None and tensor.shape != shape: 754 raise ValueError( 755 "Tensor optimizer state does not have same shape as its " 756 f"parameter: {tensor.shape} {shape}" 757 ) 758 # Flatten the tensor states: we do not need to add any right-hand-side 759 # padding since the flat optimizer state tensor is sharded via 760 # `_get_shard()`, which pads the shard as needed (just like for the flat 761 # parameter) 762 cpu_device = torch.device("cpu") 763 tensors_to_flatten = [ 764 torch.flatten(state_value.to(cpu_device)) 765 if state_value is not None 766 else torch.flatten( 767 torch.zeros( 768 size=shape, 769 dtype=dtype, 770 device=cpu_device, 771 ) 772 ) 773 for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes) 774 ] 775 flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel) 776 flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined] 777 assert flat_tensor.shape == flat_param_shape, ( 778 f"tensor optim state: {flat_tensor.shape} " 779 f"flat parameter: {flat_param_shape}" 780 ) 781 return flat_tensor 782 783 784def _flatten_zero_dim_tensor_optim_state( 785 state_name: str, 786 zero_dim_tensors: List[torch.Tensor], 787 unflat_param_names: List[str], 788) -> torch.Tensor: 789 """ 790 Flattens the zero-dimension tensor optimizer state given by the values 791 ``zero_dim_tensors`` for the state ``state_name`` for a single flat 792 parameter corresponding to the unflattened parameter names 793 ``unflat_param_names`` by enforcing that all tensors are the same and using 794 that common value. 795 796 NOTE: The requirement that the tensors are the same across all unflattened 797 parameters comprising the flat parameter is needed to maintain the 798 invariant that FSDP performs the same computation as its non-sharded 799 equivalent. This means that none of the unflattened parameters can be 800 missing this state since imposing a value may differ from having no value. 801 For example, for Adam's "step", no value means maximum bias correction, 802 while having some positive value means less bias correction. 803 804 Args: 805 state_name (str): Optimizer state name. 806 zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state 807 for the unflattened parameters corresponding to the single 808 flat parameter. 809 unflat_param_names (List[str]): A :class:`list` of unflattened 810 parameter names corresponding to the single flat parameter. 811 812 Returns: 813 torch.Tensor: A zero-dimensional tensor giving the value of the state 814 ``state_name`` for all unflattened parameters corresponding to the 815 names ``unflat_param_names``. 816 """ 817 non_none_tensors = [t for t in zero_dim_tensors if t is not None] 818 # Enforce that all have the same value and dtype 819 values_set = {t.item() if t is not None else None for t in zero_dim_tensors} 820 dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors} 821 if ( 822 len(non_none_tensors) != len(zero_dim_tensors) 823 or len(values_set) != 1 824 or len(dtypes) != 1 825 ): 826 raise ValueError( 827 "All unflattened parameters comprising a single flat " 828 "parameter must have scalar state with the same value and dtype " 829 f"but got values {values_set} and dtypes {dtypes} for state " 830 f"{state_name} and unflattened parameter names " 831 f"{unflat_param_names}" 832 ) 833 value = next(iter(values_set)) 834 dtype = next(iter(dtypes)) 835 return torch.tensor(value, dtype=dtype, device=torch.device("cpu")) 836 837 838def _flatten_non_tensor_optim_state( 839 state_name: str, 840 non_tensors: List[Any], 841 unflat_param_names: List[str], 842) -> Any: 843 """ 844 Flattens the non-tensor optimizer state given by the values ``non_tensors`` 845 for the state ``state_name`` for a single flat parameter corresponding 846 to the unflattened parameter names ``unflat_param_names`` by enforcing that 847 all values are the same and using that common value. 848 849 See the note in :func:`_flatten_zero_dim_tensor_optim_state`. 850 851 Args: 852 state_name (str): Optimizer state name. 853 non_tensors (List[Any]): Non-tensor optimizer state for the unflattened 854 parameters corresponding to the single flat parameter. 855 unflat_param_names (List[str]): A :class:`list` of unflattened 856 parameter names corresponding to the single flat parameter. 857 858 Returns: 859 Any: A non-tensor giving the value of the state ``state_name`` for all 860 unflattened parameters corresponding to the names 861 ``unflat_param_names``. 862 """ 863 non_none_non_tensors = [nt for nt in non_tensors if nt is not None] 864 # Enforce that all have the same value (same type already checked) 865 non_tensor_set = set(non_tensors) 866 if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1: 867 raise ValueError( 868 "All unflattened parameters comprising a single flat " 869 "parameter must have scalar state with the same value and dtype " 870 f"but got values {non_tensor_set} for state {state_name} and " 871 f"unflattened parameter names {unflat_param_names}" 872 ) 873 non_tensor = next(iter(non_tensor_set)) 874 return non_tensor 875 876 877def _rekey_sharded_optim_state_dict( 878 sharded_osd: Dict[str, Any], 879 model: nn.Module, 880 optim: torch.optim.Optimizer, 881 optim_input: Optional[ 882 Union[ 883 List[Dict[str, Any]], 884 Iterable[nn.Parameter], 885 ] 886 ], 887 using_optim_input: bool, 888 is_named_optimizer: bool = False, 889) -> Dict[str, Any]: 890 """ 891 Rekeys the optimizer state dict from unflattened parameter names to flat 892 parameter IDs according to the calling rank's ``optim``, which may be 893 different across ranks. In particular, the unflattened parameter names are 894 represented as :class:`_OptimStateKey` s. 895 """ 896 param_to_fqns = _get_param_to_fqns(model) 897 flat_param_to_fqn = _get_flat_param_to_fqn(model) 898 param_to_param_key: Dict[nn.Parameter, Union[int, str]] = cast( 899 Dict[nn.Parameter, Union[int, str]], 900 ( 901 _get_param_to_param_id_from_optim_input(model, optim_input) 902 if using_optim_input 903 else _get_param_to_param_key( 904 optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn 905 ) 906 ), 907 ) 908 # All parameter keys in `param_to_param_key` should be in 909 # `param_to_fqns` -- strict inequality follows when not all parameters are 910 # passed to the optimizer 911 assert len(param_to_param_key) <= len(param_to_fqns) 912 913 unflat_param_names_to_flat_param_key: Dict[ 914 Tuple[str, ...], Union[int, str] 915 ] = {} # for "state" 916 unflat_param_name_to_flat_param_key: Dict[ 917 str, Union[int, str] 918 ] = {} # for "param_groups" 919 for param, unflat_param_names in param_to_fqns.items(): 920 if param not in param_to_param_key: 921 # This parameter was not passed to the optimizer 922 continue 923 flat_param_key = param_to_param_key[param] 924 unflat_param_names_to_flat_param_key[tuple(unflat_param_names)] = flat_param_key 925 for unflat_param_name in unflat_param_names: 926 unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key 927 928 sharded_osd_state = sharded_osd["state"] 929 rekeyed_osd_state: Dict[Union[str, int], Any] = {} 930 for key, param_state in sharded_osd_state.items(): 931 if isinstance(key, str): 932 rekeyed_osd_state[key] = param_state 933 continue 934 flat_param_key = unflat_param_names_to_flat_param_key.get( 935 key.unflat_param_names, key.unflat_param_names 936 ) 937 rekeyed_osd_state[flat_param_key] = param_state 938 939 # Only process param_groups if it exists in sharded_osd 940 if "param_groups" in sharded_osd: 941 rekeyed_osd_param_groups: List[Dict[str, Any]] = [] 942 for unflat_param_group in sharded_osd["param_groups"]: 943 flat_param_group = copy.deepcopy(unflat_param_group) 944 flat_param_keys = sorted( 945 { 946 unflat_param_name_to_flat_param_key[unflat_param_name] 947 for unflat_param_name in unflat_param_group["params"] 948 } 949 ) 950 flat_param_group["params"] = flat_param_keys 951 rekeyed_osd_param_groups.append(flat_param_group) 952 return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups} 953 else: 954 return {"state": rekeyed_osd_state} 955 956 957def _get_param_id_to_param_from_optim_input( 958 model: nn.Module, 959 optim_input: Optional[ 960 Union[ 961 List[Dict[str, Any]], 962 Iterable[nn.Parameter], 963 ] 964 ] = None, 965) -> Dict[int, nn.Parameter]: 966 """ 967 Constructs a mapping from parameter IDs to parameters. This may be used 968 both for models with ``FlatParameter`` s and without. 969 970 NOTE: This method is only preserved for backward compatibility. The method 971 :meth:`_get_param_key_to_param` is the preferred code path that does not 972 rely on ``optim_input``. 973 974 NOTE: We critically assume that, whether the optimizer input is a list of 975 parameters or a list of parameter groups, :class:`torch.optim.Optimizer` 976 enumerates the parameter IDs in order. In other words, for a parameter list 977 input, the parameter IDs should be in that list order, and for a parameter 978 groups input, the parameter IDs should be in order within each parameter 979 group and in order across parameter groups. 980 981 Args: 982 model (nn.Module): Model whose parameters are passed into the 983 optimizer. 984 optim_input (Optional[Union[List[Dict[str, Any]], 985 Iterable[nn.Parameter]]]): Input passed into the optimizer 986 representing either a :class:`list` of parameter groups or an 987 iterable of parameters; if ``None``, then this method assumes the 988 input was ``model.parameters()``. (Default: ``None``) 989 990 Returns: 991 List[nn.Parameter]: Mapping from parameter IDs to parameters, 992 where the parameter ID is implicitly the index in the :class:`list`. 993 """ 994 # Assume the standard case of passing `model.parameters()` to the optimizer 995 # if `optim_input` is not specified 996 if optim_input is None: 997 return dict(enumerate(model.parameters())) 998 try: 999 params = cast(List[nn.Parameter], list(optim_input)) 1000 except TypeError as e: 1001 raise TypeError( 1002 "Optimizer input should be an iterable of Tensors or dicts, " 1003 f"but got {optim_input}" 1004 ) from e 1005 if len(params) == 0: 1006 raise ValueError("Optimizer input should not be empty") 1007 1008 # Check if the optimizer input represents tensors or parameter groups 1009 all_tensors = True 1010 all_dicts = True 1011 for param in params: 1012 all_tensors &= isinstance(param, torch.Tensor) 1013 all_dicts &= isinstance(param, dict) 1014 if not all_tensors and not all_dicts: 1015 raise TypeError("Optimizer input should be an iterable of Tensors or dicts") 1016 if all_tensors: 1017 return dict(enumerate(params)) 1018 assert all_dicts 1019 param_id_to_param: List[nn.Parameter] = [] 1020 for param_group in params: 1021 has_params_key = "params" in param_group # type: ignore[operator] 1022 assert has_params_key, ( 1023 'A parameter group should map "params" to a list of the ' 1024 "parameters in the group" 1025 ) 1026 # Implicitly map `flat_param_id` (current length of the list) to 1027 # `param` 1028 param_id_to_param.extend(param_group["params"]) # type: ignore[index] 1029 return dict(enumerate(param_id_to_param)) 1030 1031 1032def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]: 1033 """ 1034 Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes 1035 from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical" 1036 because ``FlatParameter`` s do not come from the original module but are 1037 registered only after FSDP has been applied. This function returns the FSDP-given 1038 name for the ``FlatParameter`` (usually module._flat_param) as opposed to the 1039 canonical FQNs returned for ``FlatParameter`` s in ``_common_utils._get_param_to_fqns(...)``). 1040 1041 Consequently, this function will only return a non-empty mapping if FSDP was 1042 applied with ``use_orig_params=False`` as, otherwise, the original parameters 1043 are used within the module and there would be no ``FlatParameter`` s in the module. 1044 1045 """ 1046 1047 def module_fn(module, prefix, tree_level, flat_param_to_fqn): 1048 for param_name, param in _named_parameters_with_duplicates( 1049 module, recurse=False 1050 ): 1051 if not isinstance(param, FlatParameter): 1052 continue 1053 fqn = clean_tensor_name(prefix + param_name) 1054 flat_param_to_fqn[param] = fqn 1055 1056 def return_fn(flat_param_to_fqn): 1057 return flat_param_to_fqn 1058 1059 flat_param_to_fqn_ret: Dict[FlatParameter, str] = {} 1060 return _apply_to_modules( 1061 model, 1062 module_fn, 1063 return_fn, 1064 [fqn for fqn, _ in _named_parameters_with_duplicates(model)], 1065 flat_param_to_fqn_ret, 1066 ) 1067 1068 1069def _get_param_key_to_param( 1070 optim: torch.optim.Optimizer, 1071 model: Optional[nn.Module] = None, 1072 is_named_optimizer: bool = False, 1073 param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None, 1074 flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None, 1075) -> Dict[Union[int, str], nn.Parameter]: 1076 """ 1077 Constructs a mapping from parameter keys to parameters. For the regular 1078 optimizers, the keys are parameter IDs. For NamedOptimizer, the keys 1079 are FQNs. This API may be used both for models with ``FlatParameter`` s and 1080 without. 1081 """ 1082 clean_fqn_to_curr_fqn: Dict[str, str] = {} 1083 if is_named_optimizer: 1084 assert ( 1085 param_to_fqns is not None and flat_param_to_fqn is not None 1086 ), "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None." 1087 assert model is not None 1088 for key, _ in _named_parameters_with_duplicates(model): 1089 clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key 1090 1091 param_key_to_param: Dict[Union[str, int], nn.Parameter] = {} 1092 pid = 0 1093 for param_group in optim.param_groups: 1094 if is_named_optimizer: 1095 for param in param_group["params"]: 1096 assert flat_param_to_fqn is not None 1097 if param in flat_param_to_fqn: 1098 # FlatParameter case 1099 key = flat_param_to_fqn[param] 1100 else: 1101 assert param_to_fqns is not None 1102 # use_orig_params case 1103 assert len(param_to_fqns[param]) == 1 1104 key = param_to_fqns[param][0] 1105 try: 1106 key = clean_fqn_to_curr_fqn[key] 1107 except KeyError as e: 1108 raise KeyError( 1109 f"Can't find {key} from {list(clean_fqn_to_curr_fqn.keys())}." 1110 ) from e 1111 param_key_to_param[key] = param 1112 else: 1113 for param in param_group["params"]: 1114 param_key_to_param[pid] = param 1115 pid += 1 1116 1117 return param_key_to_param 1118 1119 1120def _get_param_to_param_key( 1121 optim: torch.optim.Optimizer, 1122 model: Optional[nn.Module] = None, 1123 is_named_optimizer: bool = False, 1124 param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None, 1125 flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None, 1126) -> Dict[nn.Parameter, Union[int, str]]: 1127 """ 1128 Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API 1129 only supports the case where `optim` is a regular optimizer, not NamedOptimizer. 1130 So the parameter keys will be parameter ids. 1131 """ 1132 param_id_to_param = _get_param_key_to_param( 1133 optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn 1134 ) 1135 return {param: param_id for param_id, param in param_id_to_param.items()} 1136 1137 1138def _get_param_to_param_id_from_optim_input( 1139 model: nn.Module, 1140 optim_input: Optional[ 1141 Union[ 1142 List[Dict[str, Any]], 1143 Iterable[nn.Parameter], 1144 ] 1145 ] = None, 1146) -> Dict[nn.Parameter, int]: 1147 """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`.""" 1148 param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input) 1149 return {param: param_id for param_id, param in param_id_to_param.items()} 1150 1151 1152def _check_missing_keys_on_rank( 1153 r0_optim_state_keys: List[_OptimStateKey], 1154 optim_state_key_to_param_key: Dict[_OptimStateKey, Union[str, int]], 1155 param_key_to_param: Dict[Union[str, int], nn.Parameter], 1156 group: Optional[dist.ProcessGroup], 1157) -> None: 1158 # Ensure that all ranks have at least the optimizer states needed by 1159 # rank 0's optimizer 1160 missing_keys: List[_OptimStateKey] = [] 1161 for r0_optim_state_key in r0_optim_state_keys: 1162 if r0_optim_state_key not in optim_state_key_to_param_key: 1163 # A parameter from rank 0's optimizer does not exist for this 1164 # rank's optimizer 1165 missing_keys.append(r0_optim_state_key) 1166 continue 1167 param_key = optim_state_key_to_param_key[r0_optim_state_key] 1168 if isinstance(param_key, int): 1169 assert param_key >= 0 and param_key < len( 1170 param_key_to_param 1171 ), "Check the `param_key_to_param` construction" 1172 # We cannot use FSDPState.compute_device as this API is a global view. 1173 device = _get_pg_default_device(group) 1174 num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) 1175 dist.all_reduce(num_missing, group=group) 1176 if num_missing.item() > 0: 1177 obj_list = [None for _ in range(dist.get_world_size(group))] 1178 dist.all_gather_object(obj_list, missing_keys, group=group) 1179 error_msg = ( 1180 "FSDP currently requires each rank to have at least the " 1181 "optimizer states needed by rank 0's optimizer but some ranks " 1182 "are missing some of those states" 1183 ) 1184 for rank, keys in enumerate(obj_list): 1185 keys = cast(List[_OptimStateKey], keys) 1186 if len(keys) > 0: 1187 error_msg += ( 1188 f"\nRank {rank} is missing states for the parameters: " 1189 f"{[key.unflat_param_names for key in keys]}" 1190 ) 1191 raise RuntimeError(error_msg) 1192 1193 1194def _map_param_key_to_optim_keys( 1195 optim_state_dict: Dict[str, Any], 1196 group: Optional[dist.ProcessGroup], 1197 param_key_to_param: Dict[Union[int, str], nn.Parameter], 1198 param_to_fqns: Dict[nn.Parameter, List[str]], 1199 fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo], 1200 merge_keys: bool = False, 1201) -> Tuple[List[_OptimStateKey], Dict[_OptimStateKey, Union[int, str]]]: 1202 """ 1203 Construct the local mapping between the ``_OptimStateKey`` and parameter keys 1204 and all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0 1205 must contain all the ``_OptimStateKey``, an exception will be raised otherwise. 1206 Note that ``merge_keys`` should equal to ``use_orig_params``. 1207 """ 1208 rank = dist.get_rank(group) 1209 optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]] = {} # local 1210 all_optim_state_keys: List[_OptimStateKey] = [] 1211 1212 for param_key, param in param_key_to_param.items(): 1213 # Do not include parameters without state to avoid empty mappings 1214 # just like in normal `torch.optim.Optimizer.state_dict()` 1215 if param_key not in optim_state_dict["state"]: 1216 continue 1217 fqns = param_to_fqns[param] 1218 is_fsdp_managed = isinstance(param, FlatParameter) 1219 if is_fsdp_managed: 1220 assert fqns[0] in fqn_to_fsdp_param_info, ( 1221 fqns[0], 1222 list(fqn_to_fsdp_param_info.keys()), 1223 ) 1224 is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info 1225 optim_state_key = _OptimStateKey( 1226 unflat_param_names=tuple(fqns), 1227 is_fsdp_managed=is_fsdp_managed, 1228 ) 1229 if rank == 0 or merge_keys: 1230 all_optim_state_keys.append(optim_state_key) 1231 optim_state_key_to_param_key[optim_state_key] = param_key 1232 1233 if merge_keys: 1234 all_keys: List[List[_OptimStateKey]] = [ 1235 [] for _ in range(dist.get_world_size(group)) 1236 ] 1237 dist.all_gather_object(all_keys, all_optim_state_keys, group=group) 1238 merge_all_optim_state_keys = [ 1239 key for local_keys in all_keys for key in local_keys 1240 ] 1241 all_optim_state_keys = sorted(set(merge_all_optim_state_keys)) 1242 else: 1243 key_obj_list: List[Optional[List[_OptimStateKey]]] = ( 1244 [all_optim_state_keys] if rank == 0 else [None] 1245 ) 1246 dist.broadcast_object_list(key_obj_list, src=0, group=group) 1247 assert key_obj_list[0] is not None 1248 all_optim_state_keys = key_obj_list[0] 1249 _check_missing_keys_on_rank( 1250 all_optim_state_keys, 1251 optim_state_key_to_param_key, 1252 param_key_to_param, 1253 group, 1254 ) 1255 1256 return all_optim_state_keys, optim_state_key_to_param_key 1257 1258 1259def _unflatten_param_groups( 1260 state_dict: Dict[str, Any], 1261 param_key_to_param: Dict[Union[int, str], nn.Parameter], 1262 param_to_fqns: Dict[nn.Parameter, List[str]], 1263) -> List[Dict[str, Any]]: 1264 param_groups: List[Dict[str, Any]] = [] 1265 for flat_param_group in state_dict["param_groups"]: 1266 unflat_param_group = copy.deepcopy(flat_param_group) 1267 param_group_params = [ 1268 param_key_to_param[flat_param_key] 1269 for flat_param_key in flat_param_group["params"] 1270 ] 1271 nested_unflat_param_names = [ 1272 param_to_fqns[param] for param in param_group_params 1273 ] 1274 unflat_param_group["params"] = [ 1275 unflat_param_name 1276 for unflat_param_names in nested_unflat_param_names 1277 for unflat_param_name in unflat_param_names 1278 ] # flatten the list of lists 1279 param_groups.append(unflat_param_group) 1280 return param_groups 1281 1282 1283def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool: 1284 """ 1285 Returns whether the state_dict is from a NamedOptimizer. 1286 This function checks that the keys in the state_dict['state'] are strings 1287 (which usually are FQNs) versus integers (which usually refer to param_ids 1288 from a vanilla torch.optim.Optimizer). 1289 """ 1290 state = optim_state_dict.get("state", None) 1291 if not state: 1292 # If we cannot find a state, assume it is not NamedOptimizer as 1293 # NamedOptimizer has eager initialization. 1294 return False 1295 try: 1296 key = next(iter(state.keys())) 1297 except Exception as e: 1298 raise Exception(optim_state_dict) from e # noqa: TRY002 1299 return isinstance(key, str) 1300 1301 1302@dataclass 1303class StateInfo: 1304 # The key of these dictionaries are the state name, e.g., `exp_avg`. 1305 tensors: Dict[str, _PosDimTensorInfo] 1306 scalar_tensors: Dict[str, torch.Tensor] 1307 non_tensors: Dict[str, Any] 1308 1309 1310def _allgather_state_info( 1311 fsdp_state: _FSDPState, 1312 input_states: Dict[str, Any], 1313) -> List[Dict[str, StateInfo]]: 1314 """ 1315 Given the ``input_states``, allgather StateInfo for each state. The function 1316 uses all_gather_object to gather StateInfo so no GPU tensors are sent. 1317 """ 1318 1319 processed_state_dict: Dict[str, StateInfo] = {} 1320 gathered_state_info: List[Dict[str, StateInfo]] = [ 1321 {} for _ in range(fsdp_state.world_size) 1322 ] 1323 1324 for fqn, optim_state in input_states.items(): 1325 # Allgather the scalar tensor state, non-tensor states and tensors metadata. 1326 processed_state = StateInfo({}, {}, {}) 1327 for state_name, value in sorted_items(optim_state): 1328 if torch.is_tensor(value): 1329 if value.dim() == 0: 1330 # Ensure that `step` is on CPU. 1331 processed_state.scalar_tensors[state_name] = value.cpu() 1332 else: 1333 processed_state.tensors[state_name] = _PosDimTensorInfo( 1334 value.shape, value.dtype 1335 ) 1336 else: 1337 processed_state.non_tensors[state_name] = value 1338 processed_state_dict[fqn] = processed_state 1339 dist.all_gather_object( 1340 gathered_state_info, 1341 processed_state_dict, 1342 group=fsdp_state.process_group, 1343 ) 1344 return gathered_state_info 1345 1346 1347def _convert_all_state_info( 1348 fsdp_param_info: FSDPParamInfo, 1349 gathered_state_info: List[Dict[str, StateInfo]], 1350 input_states: Dict[str, Any], 1351 output_states: Dict[str, Dict[str, Any]], 1352) -> Tuple[Optional[torch.dtype], Dict[str, List[Optional[torch.Tensor]]]]: 1353 """ 1354 Given the ``gathered_state_info`` and ``input_states``, the API converted 1355 the StateInfo into the original state if the state is not a non-scalar 1356 tensor. For a multi-dimensional tensor, the local state will be stored in 1357 ``state_buffer`` in a correct order for later allgather purpose. 1358 """ 1359 1360 state_buffers: Dict[str, List[Optional[torch.Tensor]]] = {} 1361 1362 for fqn, gathered_state in output_states.items(): 1363 state_info = [s[fqn] for s in gathered_state_info] 1364 all_tensor_states = sorted( 1365 {n for state in state_info for n in state.tensors.keys()} 1366 ) 1367 empty_ranks: Set[int] = set() 1368 dtype: Optional[torch.dtype] = None 1369 # First check all the non-scalar states and get the information of 1370 # states on each rank. 1371 for state_name in all_tensor_states: 1372 numels = [] 1373 _empty_ranks: Set[int] = set() 1374 for rank, object_state in enumerate(state_info): 1375 numels.append(0) 1376 info = object_state.tensors.get(state_name, None) 1377 if info is not None: 1378 numels[-1] = info.shape.numel() 1379 if not dtype: 1380 dtype = info.dtype 1381 else: 1382 assert dtype == info.dtype 1383 if numels[-1] == 0: 1384 _empty_ranks.add(rank) 1385 1386 assert not empty_ranks or empty_ranks == _empty_ranks 1387 empty_ranks = _empty_ranks 1388 if state_name not in state_buffers: 1389 state_buffers[state_name] = [ 1390 None for _ in fsdp_param_info.param_indices 1391 ] 1392 local_state = input_states[fqn].get(state_name, None) 1393 # N.B. We need to move the state to compute_device. The reason is 1394 # not yet clear and we need to figure out why the state may be on a 1395 # different device. 1396 if local_state is not None: 1397 local_state = local_state.to(fsdp_param_info.state.compute_device) 1398 state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state 1399 1400 # Restoring the scalar and non-tensor states. If the corresponding 1401 # non-scalar states do not exist on the rank, we also skip the scalar 1402 # non-tensor states on that rank. 1403 for rank, object_state in enumerate(state_info): 1404 if rank in empty_ranks: 1405 continue 1406 for name, non_tensor_value in object_state.non_tensors.items(): 1407 curr_non_tensor_value = gathered_state.get(name, None) 1408 assert ( 1409 curr_non_tensor_value is None 1410 or curr_non_tensor_value == non_tensor_value 1411 ), ( 1412 f"Rank {rank} has different values for {name}: {non_tensor_value}." 1413 + f" Other ranks: {curr_non_tensor_value}" 1414 ) 1415 gathered_state[name] = non_tensor_value 1416 1417 for name, scalar_tensor_value in object_state.scalar_tensors.items(): 1418 curr_scalar_tensor_value = gathered_state.get(name, None) 1419 assert curr_scalar_tensor_value is None or torch.equal( 1420 scalar_tensor_value, curr_scalar_tensor_value 1421 ), ( 1422 f"Rank {rank} has different values for {name}: {scalar_tensor_value}." 1423 + f" Other ranks: {curr_scalar_tensor_value}" 1424 ) 1425 gathered_state[name] = scalar_tensor_value 1426 1427 return dtype, state_buffers # type: ignore[possibly-undefined] 1428 1429 1430def _unflatten_orig_param_states( 1431 fsdp_param_info: FSDPParamInfo, 1432 output_states: Dict[str, Dict[str, Any]], 1433 state_name: str, 1434 shard_state: bool, 1435 to_save: bool, 1436 cpu_offload: bool, 1437) -> None: 1438 """ 1439 Given a output state dict, ``output_states``, which the keys are FQNs to the 1440 original parameters (not FlatParameters nor parmeter ID), and the values 1441 are gathered states, unflatten the states to the original dimensions. 1442 1443 This function performs the unflattening process in-place. 1444 """ 1445 if not to_save: 1446 return 1447 flat_param = fsdp_param_info.handle.flat_param 1448 fsdp_state = fsdp_param_info.state 1449 for fqn, gathered_state in output_states.items(): 1450 value = gathered_state[state_name] 1451 param_idx = fsdp_param_info.param_indices[fqn] 1452 1453 # TODO: This solution is not general and only apply to PTD TP solution. 1454 if isinstance(value, DTensor): 1455 placement = value.placements[0] 1456 # If gathered state is a DTensor and its TP placement is not Replicate(), we need to 1457 # gather the tensor on its TP dimension before chunking them into DTensor again. 1458 if placement != Replicate(): 1459 placement_dim = placement.dim # type: ignore[attr-defined] 1460 value_local = value.redistribute(placements=(Replicate(),)) 1461 reshape_size = list(flat_param._shapes[param_idx]) 1462 reshape_size[placement_dim] *= value.device_mesh.size(0) 1463 reshape_size = torch.Size(reshape_size) 1464 value = value.reshape(reshape_size) 1465 # If gathered state is a replicate DTensor, we directly reshape it. 1466 else: 1467 value = value.reshape(flat_param._shapes[param_idx]) 1468 else: 1469 # If gathered state is a tensor, we directly reshape it into unflatten state. 1470 value = value.reshape(flat_param._shapes[param_idx]) 1471 1472 if shard_state: 1473 osd_config = fsdp_state._optim_state_dict_config 1474 if getattr(osd_config, "_use_dtensor", False): 1475 assert fsdp_state._device_mesh is not None 1476 value = _ext_chunk_dtensor( 1477 value, 1478 fsdp_state.rank, 1479 fsdp_state._device_mesh, 1480 fsdp_state._fsdp_extension, 1481 ) 1482 else: 1483 assert fsdp_state.process_group is not None 1484 value = _ext_chunk_tensor( 1485 value, 1486 fsdp_state.rank, 1487 fsdp_state.world_size, 1488 fsdp_state._device_handle.device_count(), 1489 fsdp_state.process_group, 1490 fsdp_state._fsdp_extension, 1491 ) 1492 elif not cpu_offload: 1493 with SimpleProfiler.profile("clone"): 1494 value = value.detach().clone() 1495 1496 if cpu_offload: 1497 with SimpleProfiler.profile(SimpleProfiler.Type.D2H): 1498 value = value.cpu() 1499 gathered_state[state_name] = value 1500 1501 1502def _allgather_orig_param_states( 1503 fsdp_param_info: FSDPParamInfo, 1504 gathered_state_info: List[Dict[str, StateInfo]], 1505 input_states: Dict[str, Any], 1506 shard_state: bool, 1507 to_save: bool, 1508 cpu_offload: bool, 1509) -> Dict[str, Dict[str, Any]]: 1510 """ 1511 Given the ``gathered_state_info`` and ``input_states``, the API allgathers 1512 all tensor states and restore non-tensor states from ``gathered_state_info``. 1513 """ 1514 fsdp_state = fsdp_param_info.state 1515 if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL: 1516 logger.info( 1517 "Memory Summary before calling to _allgather_orig_param_states %s", 1518 fsdp_state._device_handle.memory_summary(), 1519 ) 1520 1521 output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states.keys()} 1522 1523 dtype, state_buffers = _convert_all_state_info( 1524 fsdp_param_info, gathered_state_info, input_states, output_states 1525 ) 1526 1527 if len(state_buffers) == 0: 1528 return output_states 1529 1530 has_state_params: List[bool] = [ 1531 True if fqn in output_states else False 1532 for fqn, idx in fsdp_param_info.param_indices.items() 1533 ] 1534 1535 # Loop through the ``state_buffers`` and construct the flattened, concatenated, 1536 # sharded states. The size of the constructed state will be the same size as 1537 # flat_param (also sharded). 1538 # Then we perform an allgather_into_tensor to get the full flat_param state. 1539 # The full flat_param state is the result of concatenation of multiple states 1540 # the order of of flat_param._fqns. 1541 # The final step is to split the flat_param state into original param states 1542 # and return the result. 1543 flat_param = fsdp_param_info.handle.flat_param 1544 empty_func = functools.partial( 1545 torch.empty, dtype=dtype, device=fsdp_state.compute_device 1546 ) 1547 gathered_tensor = empty_func(flat_param._padded_unsharded_size) 1548 # Synchronize can be slow but this will be easier for us to debug. 1549 fsdp_state._device_handle.synchronize() 1550 for state_name, buffers in state_buffers.items(): 1551 local_buffers: List[torch.Tensor] = [] 1552 begin = fsdp_state.rank * flat_param._sharded_size.numel() 1553 # End is inclusive. 1554 end = begin + flat_param._sharded_size.numel() - 1 1555 # param_idx corresponds to the parameter index in the FlatParameter. 1556 mem_offset, param_idx = 0, 0 1557 for numel, is_padding in zip( 1558 flat_param._numels_with_padding, flat_param._is_padding_mask 1559 ): 1560 frozen_and_no_state = not is_padding and ( 1561 not fsdp_param_info.param_requires_grad[param_idx] 1562 and not has_state_params[param_idx] 1563 ) 1564 1565 if is_padding or frozen_and_no_state: 1566 # This memory range is a padding or the param is frozen and does 1567 # not require gradient. For the later case, we treat it as a 1568 # padding and add empty values to the local_buffers. 1569 1570 padding_begin, padding_end = mem_offset, mem_offset + numel - 1 1571 if padding_begin <= begin <= padding_end: 1572 # The range is an align padding before the first parameter in 1573 # the shard. The shard includes parts of this align padding. 1574 padding_len = ( 1575 padding_end - begin + 1 1576 if end >= padding_end 1577 else end - begin + 1 1578 ) 1579 elif padding_begin <= end <= padding_end: 1580 # The range is an align padding after the last parameter in 1581 # the shard. The shard includes parts of this align padding. 1582 padding_len = ( 1583 end - padding_begin + 1 1584 if begin <= padding_begin 1585 else end - begin + 1 1586 ) 1587 elif begin < padding_begin <= padding_end < end: 1588 # The range is an align padding that is completely in the 1589 # shard. 1590 padding_len = numel 1591 else: 1592 padding_len = 0 1593 if padding_len: 1594 local_buffers.append(empty_func(padding_len)) 1595 1596 if not is_padding: 1597 # This memory range is a parameter in FlatParameter. So there 1598 # should be an corresponding state in the optimizer unless the 1599 # parameter is frozen, which we treat it as a padding above. 1600 1601 # We need to check if this rank owns the buffer. If this is None: 1602 # 1.) the rank does not own any part of the original parameter. 1603 # As a result, there is no corresponding optimizer state on 1604 # the rank as well. 1605 # 2.) the parameter is frozen AND no optimizer state for the 1606 # parameter. If a parameter is frozen, there can still be 1607 # optimizer state if the parameter is not frozen in the 1608 # previous steps. 1609 if buffers[param_idx] is not None: 1610 local_buffers.append(cast(torch.Tensor, buffers[param_idx])) 1611 param_idx += 1 1612 1613 mem_offset += numel 1614 1615 shard_numel_padded = flat_param._sharded_size.numel() - ( 1616 sum(t.numel() for t in local_buffers) 1617 ) 1618 1619 assert flat_param._shard_numel_padded == shard_numel_padded, ( 1620 "Manually calculated _sharded_numel_padded is incorrect. " 1621 f"_shard_numel_padded={flat_param._shard_numel_padded}, " 1622 f"shard_numel_padded={shard_numel_padded}, " 1623 f"_sharded_size.numel={flat_param._sharded_size.numel()}, " 1624 f"_numels_with_padding={flat_param._numels_with_padding}, " 1625 f"begin={begin}, end={end}," 1626 ) 1627 if shard_numel_padded > 0: 1628 # Add right-handed padding. 1629 local_buffers.append(empty_func(shard_numel_padded)) 1630 local_shard = torch.cat(local_buffers) 1631 assert local_shard.numel() * fsdp_state.world_size == gathered_tensor.numel(), ( 1632 "The size of local shard times the world size should equal to the " 1633 "gathered tensor size. The inconsistency may be from a bug of " 1634 "FlatParameter's metadata or the reconstruction logic in optimizer " 1635 "state dict." 1636 ) 1637 fsdp_state._device_handle.synchronize() 1638 with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): 1639 dist.all_gather_into_tensor( 1640 gathered_tensor, local_shard, group=fsdp_state.process_group 1641 ) 1642 # Synchronize can be slow but this will be easier for us to debug. 1643 fsdp_state._device_handle.synchronize() 1644 1645 unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()] 1646 flat_param_handle = fsdp_param_info.handle 1647 orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor) 1648 assert len(orig_states) == len(fsdp_param_info.param_indices), ( 1649 "The number of parameters from FlatParameter is not consistent to " 1650 "the number of states used by optimizer state dict reconstruction " 1651 "logic." 1652 ) 1653 for fqn, idx in fsdp_param_info.param_indices.items(): 1654 if fsdp_param_info.param_requires_grad[idx] or fqn in output_states: 1655 output_states[fqn][state_name] = orig_states[idx] 1656 1657 _unflatten_orig_param_states( 1658 fsdp_param_info, 1659 output_states, 1660 state_name, 1661 shard_state, 1662 to_save, 1663 cpu_offload, 1664 ) 1665 1666 del gathered_tensor 1667 return output_states 1668 1669 1670def _gather_all_orig_param_state( 1671 fsdp_param_info: FSDPParamInfo, 1672 input_states: Dict[str, Any], 1673 shard_state: bool, 1674 to_save: bool, 1675 cpu_offload: bool, 1676) -> Dict[str, Any]: 1677 """ 1678 Given a optimizer state dict, ``input_states``, which the keys are FQNs to the 1679 original parameters (not FlatParameters nor parmeter ID), gather all the 1680 states and unflatten them to the original dimensions. Note that all the 1681 params referred by the ``input_states`` must be managed by FSDP. 1682 """ 1683 fsdp_state = fsdp_param_info.state 1684 if ( 1685 fsdp_state.world_size == 1 1686 or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD 1687 ): 1688 return input_states if to_save else {} 1689 1690 with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): 1691 with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ): 1692 gathered_state_info = _allgather_state_info(fsdp_state, input_states) 1693 output_states = _allgather_orig_param_states( 1694 fsdp_param_info, 1695 gathered_state_info, 1696 input_states, 1697 shard_state, 1698 to_save, 1699 cpu_offload, 1700 ) 1701 if to_save: 1702 for key, idx in fsdp_param_info.param_indices.items(): 1703 if key in output_states: 1704 continue 1705 if not fsdp_param_info.param_requires_grad[idx]: 1706 continue 1707 1708 raise RuntimeError( 1709 f"{key} is not in the output state. " 1710 "The FSDPParamInfo has the param keys " 1711 f"{sorted(fsdp_param_info.param_indices.keys())} while " 1712 "the output_states has the param keys " 1713 f"{sorted(output_states.keys())}." 1714 ) 1715 return output_states 1716 else: 1717 return {} 1718 1719 1720def _convert_state_with_orig_params( 1721 all_optim_state_keys: List[_OptimStateKey], 1722 optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]], 1723 fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo], 1724 optim_state_dict: Dict[Union[str, int], Any], 1725 to_save: bool, 1726 shard_state: bool, 1727 cpu_offload: bool = True, 1728) -> Dict[str, Any]: 1729 fsdp_osd_state: Dict[str, Any] = {} 1730 # This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo 1731 # usually corresponds to multiple parameters. We could not use FSDPParamInfo 1732 # as the key because FSDPParamInfo is not hashable. As a result, we fall back 1733 # to `id(FSDPParamInfo)`, which the type is an integer. 1734 all_states: Dict[int, Dict[str, Any]] = {} 1735 # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers 1736 # across ranks 1737 for optim_state_key in all_optim_state_keys: 1738 param_key: Union[str, int, None] = optim_state_key_to_param_key.get( 1739 optim_state_key, None 1740 ) 1741 1742 if param_key is None and not optim_state_key.is_fsdp_managed: 1743 continue 1744 1745 if optim_state_key.is_fsdp_managed: 1746 fqn = optim_state_key.unflat_param_names[0] 1747 fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None) 1748 if fsdp_param_info is None: 1749 # This can happen if the not all FSDP instances have all the 1750 # parameters. This can happen with FSDP + some MPMD style 1751 # parallelism. 1752 1753 # TODO: it is unclear if we need to do the same check with 1754 # non-FSDP managed keys. 1755 continue 1756 state = {} if param_key is None else optim_state_dict[param_key] 1757 if id(fsdp_param_info) not in all_states: 1758 all_states[id(fsdp_param_info)] = {} 1759 all_states[id(fsdp_param_info)][fqn] = state 1760 1761 elif to_save: 1762 assert len(optim_state_key.unflat_param_names) == 1 1763 unflat_param_name = optim_state_key.unflat_param_names[0] 1764 with SimpleProfiler.profile("none_fsdp_managed_copy"): 1765 param_key = cast(Union[str, int], param_key) 1766 fsdp_osd_state[unflat_param_name] = copy.copy( 1767 optim_state_dict[param_key] 1768 ) 1769 if cpu_offload: 1770 for state_name, value in sorted_items( 1771 fsdp_osd_state[unflat_param_name] 1772 ): 1773 if not torch.is_tensor(value): 1774 continue 1775 fsdp_osd_state[unflat_param_name][state_name] = value.cpu() 1776 1777 # Instead of gathering the state of each parameter individually, we perform 1778 # the gathering all at once to speed up the process. 1779 for _all_states in all_states.values(): 1780 fqn = next(iter(_all_states.keys())) 1781 fsdp_param_info = fqn_to_fsdp_param_info[fqn] 1782 assert len(fsdp_param_info.param_requires_grad) > 0, ( 1783 "With use_orig_params, FSDPParamInfo should have requires_grad " 1784 "information. However, the length is zero." 1785 ) 1786 for key, idx in fsdp_param_info.param_indices.items(): 1787 if key in _all_states: 1788 continue 1789 if not fsdp_param_info.param_requires_grad[idx]: 1790 continue 1791 raise RuntimeError( 1792 f"{key} is not in the optimizer state. " 1793 "The FSDPParamInfo has the param keys " 1794 f"{sorted(fsdp_param_info.param_indices.keys())} while " 1795 "the optimizer has the param keys " 1796 f"{sorted(_all_states.keys())}." 1797 ) 1798 fsdp_osd_state.update( 1799 _gather_all_orig_param_state( 1800 fsdp_param_info, 1801 _all_states, 1802 shard_state, 1803 to_save, 1804 cpu_offload, 1805 ) 1806 ) 1807 1808 return fsdp_osd_state 1809 1810 1811def _convert_state_with_flat_params( 1812 all_optim_state_keys: List[_OptimStateKey], 1813 optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]], 1814 fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo], 1815 optim_state_dict: Dict[Union[str, int], Any], 1816 to_save: bool, 1817 shard_state: bool, 1818 cpu_offload: bool = True, 1819) -> Dict[str, Any]: 1820 fsdp_osd_state: Dict[str, Any] = {} 1821 # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers 1822 # across ranks 1823 for optim_state_key in all_optim_state_keys: 1824 param_key: Union[str, int, None] = optim_state_key_to_param_key.get( 1825 optim_state_key, None 1826 ) 1827 1828 assert param_key is not None, ( 1829 "If use_orig_params is False, we must be able to find the " 1830 f"corresponding param id. {optim_state_key} {param_key}" 1831 ) 1832 1833 if optim_state_key.is_fsdp_managed: 1834 # If there are multiple unflat_param_names (not use_orig_params), 1835 # they share the same FSDPParamInfo. So the first unflat_param_name 1836 # is sufficient to fetch the FSDPParamInfo. 1837 fqn = optim_state_key.unflat_param_names[0] 1838 fsdp_param_info = fqn_to_fsdp_param_info[fqn] 1839 unflat_state = _unflatten_optim_state( 1840 fsdp_param_info, 1841 optim_state_dict[param_key], 1842 to_save, 1843 shard_state, 1844 cpu_offload, 1845 ) 1846 if to_save: 1847 assert len(unflat_state) == len(optim_state_key.unflat_param_names) 1848 for unflat_param_name, unflat_param_state in zip( 1849 optim_state_key.unflat_param_names, 1850 unflat_state, 1851 ): 1852 fsdp_osd_state[unflat_param_name] = unflat_param_state 1853 elif to_save: 1854 assert len(optim_state_key.unflat_param_names) == 1 1855 unflat_param_name = optim_state_key.unflat_param_names[0] 1856 fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key]) 1857 if cpu_offload: 1858 for state_name, value in sorted_items( 1859 fsdp_osd_state[unflat_param_name] 1860 ): 1861 if not torch.is_tensor(value): 1862 continue 1863 fsdp_osd_state[unflat_param_name][state_name] = value.cpu() 1864 1865 return fsdp_osd_state 1866 1867 1868@torch.no_grad() 1869def _optim_state_dict( 1870 model: nn.Module, 1871 optim: torch.optim.Optimizer, 1872 optim_state_dict: Dict[str, Any], 1873 optim_input: Optional[ 1874 Union[ 1875 List[Dict[str, Any]], 1876 Iterable[nn.Parameter], 1877 ] 1878 ], 1879 rank0_only: bool, 1880 shard_state: bool, 1881 group: Optional[dist.ProcessGroup], 1882 using_optim_input: bool, 1883 use_orig_params: bool = False, 1884 cpu_offload: bool = True, 1885) -> Dict[str, Any]: 1886 """ 1887 Consolidates the optimizer state and returns it as a :class:`dict` 1888 following the convention of :meth:`torch.optim.Optimizer.state_dict`, 1889 i.e. with keys ``"state"`` and ``"param_groups"``. 1890 The flat parameters in ``FSDP`` modules contained in ``model`` are mapped 1891 back to their unflattened parameters. 1892 1893 Parameter keys are not well-defined. For a regular optimizer, the optimizer 1894 state_dict contains a mapping from parameter IDs to parameter states. 1895 Parameter IDs are the order of parameters in ``optim.param_groups()`` across 1896 all the groups. This API also allows user to pass ``optim_input`` for the 1897 mapping between parameters and parameter IDs. Using ``optim_input`` is being 1898 deprecated. 1899 1900 If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not 1901 contain parameter IDs mapping but a mapping from parameter FQNs to parameter 1902 states. This API finds the mapping from FQNs to parameters if the optimizer 1903 is a ``NamedOptimizer``. 1904 1905 If ``use_orig_params`` is True, each rank will have all FSDP-managed 1906 parameters but some of these parameters may be empty due to the sharding. 1907 For a regular optim.Optimizer, states for those empty parameters will 1908 not be initialized. So, when aggregating the FQNs across ranks, no assert 1909 will be raised on a rank even if it does not have all the states -- it is 1910 valid and FSDP knows how to aggregate them. However, FSDP has to ignore 1911 handling those parameters that are not managed by FSDP and do not exist on 1912 the local rank -- those are managed by other parallelisms and FSDP does not 1913 know how to handle/aggregate them. 1914 1915 Args: 1916 model (nn.Module): Root module (which may or may not be a 1917 :class:`FullyShardedDataParallel` instance) whose parameters 1918 were passed into the optimizer ``optim``. 1919 optim (torch.optim.Optimizer): Optimizer for ``model`` 's 1920 parameters. 1921 rank0_only (bool): If ``True``, saves the populated :class:`dict` 1922 only on rank 0; if ``False``, saves it on all ranks. (Default: 1923 ``True``) 1924 shard_state (bool): If ``True``, shard and distribute all 1925 non-zero-dimension states. 1926 1927 Returns: 1928 Dict[str, Any]: A :class:`dict` containing the optimizer state for 1929 ``model`` 's original unflattened parameters and including keys 1930 "state" and "param_groups" following the convention of 1931 :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=False``, 1932 then nonzero ranks return an empty :class:`dict`. 1933 """ 1934 SimpleProfiler.reset() 1935 cm = ExitStack() 1936 cm.enter_context(SimpleProfiler.profile(SimpleProfiler.Type.ALL)) 1937 _reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model)) 1938 to_save = not rank0_only or dist.get_rank(group) == 0 or shard_state 1939 1940 with SimpleProfiler.profile("preprocessing"): 1941 param_to_fqns = _get_param_to_fqns(model) 1942 flat_param_to_fqn = _get_flat_param_to_fqn(model) 1943 is_named_optimizer = _is_named_optimizer(optim_state_dict) 1944 1945 param_key_to_param = cast( 1946 Dict[Union[int, str], nn.Parameter], 1947 ( 1948 _get_param_id_to_param_from_optim_input(model, optim_input) 1949 if using_optim_input 1950 else _get_param_key_to_param( 1951 optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn 1952 ) 1953 ), 1954 ) 1955 fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) 1956 1957 with SimpleProfiler.profile("preprocessing_with_comm"): 1958 ( 1959 all_optim_state_keys, 1960 optim_state_key_to_param_key, 1961 ) = _map_param_key_to_optim_keys( 1962 optim_state_dict, 1963 group, 1964 param_key_to_param, 1965 param_to_fqns, 1966 fqn_to_fsdp_param_info, 1967 merge_keys=use_orig_params, 1968 ) 1969 1970 with SimpleProfiler.profile("state_converting"): 1971 convert_fn = ( 1972 _convert_state_with_orig_params 1973 if use_orig_params 1974 else _convert_state_with_flat_params 1975 ) 1976 fsdp_osd_state = convert_fn( 1977 all_optim_state_keys, 1978 optim_state_key_to_param_key, 1979 fqn_to_fsdp_param_info, 1980 optim_state_dict["state"], 1981 to_save, 1982 shard_state, 1983 cpu_offload, 1984 ) 1985 1986 # At this point, communication is complete and ranks can return early if nothing 1987 # will be saved on that rank. 1988 if not to_save: 1989 return {} 1990 1991 fsdp_osd: Dict[str, Any] = {"state": fsdp_osd_state} 1992 1993 flat_param_fqns = set(flat_param_to_fqn.values()) 1994 for key, value in optim_state_dict["state"].items(): 1995 if key in fsdp_osd_state: 1996 continue 1997 if key in flat_param_fqns: 1998 continue 1999 if key in param_key_to_param: 2000 continue 2001 # This key is not recognized by FSDP. It may be a user-defined state 2002 # or some parameters state that FSDP is unable to map from 2003 # ``optim.param_groups``. 2004 warnings.warn( 2005 f"Found a optim state, {key}, that FSDP cannot process. FSDP " 2006 "will directly copy everything to the returned state_dict. In " 2007 "most cases, this is a user-defined state that is not " 2008 "associated with any particular parameter. Another possible " 2009 "case is this state is managed by TorchRec. Otherwise, there may " 2010 " be a mismatched assumption of optim_state_dict of this mode." 2011 ) 2012 fsdp_osd_state[key] = value 2013 2014 if "param_groups" in optim_state_dict: 2015 fsdp_osd["param_groups"] = _unflatten_param_groups( 2016 optim_state_dict, param_key_to_param, param_to_fqns 2017 ) 2018 2019 cm.close() 2020 SimpleProfiler.dump_and_reset("FSDP _optim_state_dict() profiling: ") 2021 2022 return fsdp_osd 2023 2024 2025def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]: 2026 """ 2027 Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo`` 2028 if the param is managed by FSDP. Shared parameters, or original parameters that 2029 are shared across multiple nn.Modules, are required to belong to one and only 2030 one FSDP instance and thus correspond to one ``FlatParameter``. Within the one 2031 ``FlatParameter``, ``FlatParameter._fqns`` only stores the first FQN of a shared 2032 parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters. 2033 """ 2034 2035 def module_fn(module, prefix, tree_level, fqn_to_param_info): 2036 fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) 2037 if fsdp_state is None: 2038 return 2039 _lazy_init(fsdp_state, module) 2040 handle = _module_handle(fsdp_state, module) 2041 if not handle: 2042 return 2043 flat_param = handle.flat_param 2044 fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, []) 2045 # NOTE: `idx` indexes into the data structures *without* padding 2046 # elements 2047 for idx, local_fqn in enumerate(flat_param._fqns): 2048 fqn = clean_tensor_name(prefix + local_fqn) 2049 if fqn in fqn_to_param_info: 2050 assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn 2051 fqn_to_param_info[fqn] = fsdp_param_info 2052 fsdp_param_info.param_indices[fqn] = idx 2053 if flat_param._params is not None: 2054 fsdp_param_info.param_requires_grad.append( 2055 flat_param._params[idx].requires_grad 2056 ) 2057 2058 def return_fn(fqn_to_param_info): 2059 return fqn_to_param_info 2060 2061 fqn_to_param_info: Dict[str, FSDPParamInfo] = {} 2062 # FlatParameter._fqns stores the local fqn, starting from the root of the 2063 # FSDP. Using _apply_to_modules() with model (may not be the FSDP root 2064 # module) allows us to construct the global fqn. 2065 return _apply_to_modules( 2066 model, 2067 module_fn, 2068 return_fn, 2069 [fqn for fqn, _ in _named_parameters_with_duplicates(model)], 2070 fqn_to_param_info, 2071 ) 2072 2073 2074@no_type_check 2075def _set_optim_use_dtensor( 2076 fsdp_state: _FSDPState, 2077 state_dict_settings: StateDictSettings, 2078) -> None: 2079 # If device_mesh is passed in when initalizing FSDP, we automatically turn the 2080 # _use_dtensor flag to be true for ShardedOptimStateDictConfig() if state_dict_type 2081 # has to be set to SHARDED_STATE_DICT. 2082 if getattr(fsdp_state, "_device_mesh", None): 2083 state_dict_type = state_dict_settings.state_dict_type 2084 if state_dict_type == StateDictType.LOCAL_STATE_DICT: 2085 raise RuntimeError( 2086 "Found state_dict_type LOCAL_STATE_DICT.", 2087 "DeviceMesh is not compatible with LOCAL_STATE_DICT.", 2088 "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.", 2089 ) 2090 else: 2091 state_dict_settings.optim_state_dict_config._use_dtensor = True 2092