# mypy: allow-untyped-defs
import contextlib
import functools
import gc
import warnings
from dataclasses import asdict, dataclass, field
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import (
    _broadcast_state_dict,
    _distribute_state_dict,
    _flatten_state_dict,
    _gather_state_dict,
    _offload_state_dict_to_cpu,
    _unflatten_state_dict,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state_if_fully_sharded_module,
    FSDP_WRAPPED_MODULE,
)
from torch.distributed.tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils._pytree import tree_map_only


__all__ = [
    "FQNS_T",
    "PrimitiveType",
    "ValueType",
    "DictValueType",
    "ListDictValueType",
    "OptimizerStateType",
    "StateDictOptions",
    "get_model_state_dict",
    "get_optimizer_state_dict",
    "get_state_dict",
    "set_model_state_dict",
    "set_optimizer_state_dict",
    "set_state_dict",
]


_FLAT_PARAM = "_flat_param"
_PG = "param_groups"
_PARAMS = "params"
_STATE = "state"

FQNS_T = Set[str]
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
ValueType = Union[
    PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"]
]
DictValueType = Dict[str, ValueType]
ListDictValueType = List[DictValueType]
OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]


_patched_state_dict: Set[Callable] = set()


@contextlib.contextmanager
def _gc_context():
    is_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if is_enabled:
            gc.enable()


@dataclass
class StateDictOptions:
    """
    This dataclass specifies how get_state_dict/set_state_dict will work.

    - ``full_state_dict``: if this is set to True, all the tensors in the
      returned state_dict will be gathered. No ShardedTensor and DTensor
      will be in the returned state_dict.

    - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
      ``full_state_dict`` is also true, then only the rank0 will get the
      state_dict and all other ranks will get empty state_dict.

    - ``ignore_frozen_params``: if the value is True, the returned state_dict
      won't contain any frozen parameters (parameters whose ``requires_grad``
      is False). The default value is False.

    - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None,
      this option indicates whether to keep the submodule prefixes in the
      state_dict keys. For example, if the submodule is ``module.pretrain`` and
      the full FQN of the parameter is ``pretrain.layer1.weight``, then the
      parameter's key in the returned state_dict will be ``pretrain.layer1.weight``
      when this option is True and ``layer1.weight`` when it is False.
      Note that if ``keep_submodule_prefixes`` is False, there may be
      conflicting FQNs, hence there should be only one submodule in
      ``submodules``.

    - ``strict``: the ``strict`` option when ``set_state_dict`` calls
      model.load_state_dict().

    - ``broadcast_from_rank0``: when this option is True, rank0 should receive a
      full state_dict and will broadcast the tensors in the state_dict /
      optim_state_dict one by one to the other ranks. The other ranks will
      receive the tensors and shard them according to the local shards in the
      model and optimizer. ``full_state_dict`` must be set to True when using
      this option. This option currently only supports DTensor, not the legacy
      ShardedTensor.
    """

    full_state_dict: bool = False
    cpu_offload: bool = False
    ignore_frozen_params: bool = False
    keep_submodule_prefixes: bool = True
    strict: bool = True
    broadcast_from_rank0: bool = False
    flatten_optimizer_state_dict: bool = False
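

# A minimal usage sketch of the options above, assuming an initialized process
# group and placeholder ``model``/``optim`` objects (an FSDP-wrapped module and
# its optimizer); it is illustrative only, not part of this module's API surface.
# With ``full_state_dict=True`` and ``cpu_offload=True``, only rank0 receives a
# gathered, CPU-resident state_dict, while the other ranks receive empty dicts:
#
#     options = StateDictOptions(full_state_dict=True, cpu_offload=True)
#     model_sd, optim_sd = get_state_dict(model, optim, options=options)
#
#     # Later, reload with broadcast_from_rank0 so only rank0 has to hold the
#     # full checkpoint in memory; the other ranks reshard what they receive.
#     load_options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
#     set_state_dict(
#         model,
#         optim,
#         model_state_dict=model_sd,
#         optim_state_dict=optim_sd,
#         options=load_options,
#     )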


@dataclass
class _StateDictInfo(StateDictOptions):
    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    shared_params_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    submodule_prefixes: Set[str] = field(default_factory=set)
    handle_model: bool = True
    handle_optim: bool = True
    fsdp_context: Callable = contextlib.nullcontext
    fsdp_modules: List[nn.Module] = field(default_factory=list)


@functools.lru_cache(maxsize=None)
def _get_fqns(
    model: nn.Module,
    name: str,
    skip_ddp_prefix: bool = True,
    skip_compiler_prefix: bool = True,
) -> FQNS_T:
    """
    This API is used to convert the name of a parameter to the FQNs. For FSDP
    without `use_orig_params`, the name of FlatParameter can be mapped to
    multiple original parameters. As a result, the return type of this function
    is `Set[str]`.

    Args:
        model (nn.Module): the root model.
        name (str): the name of the parameter.
        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix.
        skip_compiler_prefix (bool): whether to skip the compiler's `_orig_mod` prefix.

    Returns:
        The canonical FQNs based on the model traversal.
    """

    # Remove the checkpoint prefix, if it exists.
    name = name.replace(_CHECKPOINT_PREFIX, "")
    if "." not in name:
        return {name}

    obj_names = name.split(".")
    fqn_obj_names = []
    curr_obj = model
    for i, curr_obj_name in enumerate(obj_names):
        if isinstance(curr_obj, DDP):
            assert curr_obj_name == "module"
            curr_obj = curr_obj.module
            if not skip_ddp_prefix:
                fqn_obj_names.append(curr_obj_name)
        elif isinstance(curr_obj, FSDP):
            if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM:
                prefix = ".".join(fqn_obj_names)
                flat_param = getattr(curr_obj, _FLAT_PARAM)
                if prefix:
                    prefix = f"{prefix}."
                return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
            curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
            if curr_obj_name != FSDP_WRAPPED_MODULE:
                fqn_obj_names.append(curr_obj_name)
                curr_obj = getattr(curr_obj, curr_obj_name)
        elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
            assert curr_obj_name == "_orig_mod"
            curr_obj = curr_obj._orig_mod
            if not skip_compiler_prefix:
                fqn_obj_names.append(curr_obj_name)
        else:
            fqn_obj_names.append(curr_obj_name)
            if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
                if i != len(obj_names) - 1:
                    raise RuntimeError("Expect `_extra_state` to be the last obj name")
            else:
                curr_obj = getattr(curr_obj, curr_obj_name)

    return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")}
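

# Illustrative sketch of the canonical-FQN conversion implemented above (the
# wrapped model is hypothetical). ``named_parameters()`` on a compiled,
# DDP-wrapped module yields prefixed names, while ``_get_fqns`` strips the
# wrapper prefixes so the keys match the plain module:
#
#     # model = torch.compile(DDP(MyModel()))
#     # named_parameters() key:  "_orig_mod.module.layer1.weight"
#     # _get_fqns(model, "_orig_mod.module.layer1.weight") -> {"layer1.weight"}
#     # _get_fqns(model, name, skip_ddp_prefix=False) keeps DDP's "module." prefix.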


class _EXTRA_STATE:
    pass


def _iterate_valid_model_state(model):
    visited_modules: Set[nn.Module] = set()

    def recurse(module: nn.Module, curr_fqn: str) -> Generator:
        visited_modules.add(module)

        curr_fqn = f"{curr_fqn}." if curr_fqn else ""
        for name, submodule in module.named_children():
            if submodule in visited_modules:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield from recurse(submodule, new_fqn)

        for name, obj in chain(
            module.named_buffers(recurse=False), module.named_parameters(recurse=False)
        ):
            if name in module._non_persistent_buffers_set:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield new_fqn, obj

        if (
            getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state)
            != nn.Module.get_extra_state
        ):
            new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
            yield new_fqn, _EXTRA_STATE()

    yield from recurse(model, "")


def _verify_options(
    model: nn.Module,
    optims: Tuple[torch.optim.Optimizer, ...],
    optim_only: bool,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> _StateDictInfo:
    """
    Verify the model and options passed by the user and generate _StateDictInfo.
    """
    if submodules:
        warnings.warn(
            "Getting submodules only model/optim state_dict is deprecated and "
            "will be removed in 2.5. This feature can be achieved by manually "
            "filtering out the state_dict returned from get_state_dict.",
            FutureWarning,
        )
    if optim_only and not optims:
        raise RuntimeError(
            "Optimizers are not passed in but optim_only is set to True."
        )

    options = options or StateDictOptions()

    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    shared_params_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    for name, param in _iterate_valid_model_state(model):
        if isinstance(param, _EXTRA_STATE):
            continue

        fqns = _get_fqns(model, name)
        fqn = fqn_param_mapping.get(param, None)
        if fqn is not None:
            cast(Set[str], fqn_param_mapping[param]).update(fqns)
            shared_params_mapping[param] = fqn_param_mapping[param]
        else:
            # We need to do copy as _get_fqns is lru_cached
            fqn_param_mapping[param] = fqns.copy()
        for fqn in fqns:
            if not isinstance(param, _EXTRA_STATE):
                fqn_param_mapping[fqn] = param

    for param_, fqns_ in list(shared_params_mapping.items()):
        for fqn in fqns_:
            shared_params_mapping[fqn] = cast(torch.Tensor, param_)

    submodule_prefixes: Set[str] = set()
    if submodules:
        submodules = set(submodules)
        for name, module in model.named_modules():
            if module not in submodules:
                continue
            fqns = _get_fqns(model, name)
            assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
            submodule_prefixes.update(f"{fqn}." for fqn in fqns)

    if options.broadcast_from_rank0 and not options.full_state_dict:
        raise ValueError(
            "full_state_dict must be True when broadcast_from_rank0 is True."
        )
    fsdp_modules = FSDP.fsdp_modules(model)
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig
    fsdp_context: Callable
    if fsdp_modules:
        # The FSDP APIs only work if at least one FSDP instance exists.
        if options.full_state_dict:
            state_dict_config = FullStateDictConfig(
                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
            )
            optim_state_dict_config = FullOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload,
                rank0_only=(options.cpu_offload or options.broadcast_from_rank0),
            )
            state_dict_type = StateDictType.FULL_STATE_DICT
        else:
            state_dict_config = ShardedStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            optim_state_dict_config = ShardedOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            state_dict_type = StateDictType.SHARDED_STATE_DICT

        @contextlib.contextmanager
        def fsdp_state_dict_type_without_warning(
            module,
            state_dict_type,
            state_dict_config,
            optim_state_dict_config,
        ):
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore", message="FSDP.state_dict_type", category=FutureWarning
                )
                with FSDP.state_dict_type(
                    module=module,
                    state_dict_type=state_dict_type,
                    state_dict_config=state_dict_config,
                    optim_state_dict_config=optim_state_dict_config,
                ):
                    yield

        fsdp_context = functools.partial(
            fsdp_state_dict_type_without_warning,
            module=model,
            state_dict_type=state_dict_type,
            state_dict_config=state_dict_config,
            optim_state_dict_config=optim_state_dict_config,
        )
    else:
        fsdp_context = contextlib.nullcontext

    return _StateDictInfo(
        **asdict(options),
        fqn_param_mapping=fqn_param_mapping,
        shared_params_mapping=shared_params_mapping,
        submodule_prefixes=submodule_prefixes,
        fsdp_context=fsdp_context,
        fsdp_modules=cast(List[nn.Module], fsdp_modules),
        handle_model=not optim_only,
        handle_optim=(len(optims) > 0),
    )
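

# Hedged sketch of what ``_verify_options`` builds for tied weights (the module
# below is hypothetical). When two FQNs refer to the same tensor, the tensor
# maps to the set of both FQNs and each FQN shows up in ``shared_params_mapping``:
#
#     # class Tied(nn.Module):
#     #     def __init__(self):
#     #         super().__init__()
#     #         self.embed = nn.Embedding(16, 4)
#     #         self.out = nn.Linear(4, 16, bias=False)
#     #         self.out.weight = self.embed.weight  # weight tying
#     #
#     # info = _verify_options(Tied(), (), optim_only=False)
#     # info.fqn_param_mapping[tied_weight] == {"embed.weight", "out.weight"}
#     # info.fqn_param_mapping["embed.weight"] is info.fqn_param_mapping["out.weight"]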


def _verify_state_dict(
    model_state_dict: Dict[str, ValueType],
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    for module in info.fsdp_modules:
        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
        assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."

    # Verify if the model_state_dict and optim_state_dict are valid. This API
    # should give the users an explicit error message to debug or report.
    if (
        info.handle_model
        and not model_state_dict
        and not info.submodule_prefixes
        and not info.ignore_frozen_params
        and not (info.cpu_offload and info.full_state_dict)
        and info.strict
        and not info.broadcast_from_rank0
    ):
        raise RuntimeError(
            "The option indicates that model state_dict is required to save "
            "or load, but model state_dict is empty. "
            f"rank = {dist.get_rank()}."
        )

    if info.handle_optim:
        if (
            not optim_state_dict
            and not (info.cpu_offload and info.full_state_dict)
            and (not info.broadcast_from_rank0)
        ):
            raise RuntimeError(
                "The option indicates that optimizer state_dict is required to "
                f"save or load, but optim state_dict is empty. {optim_state_dict}"
            )

    for key in model_state_dict.keys():
        if _FLAT_PARAM in key:
            raise RuntimeError(
                f"{key} contains {_FLAT_PARAM}. This can happen if the model "
                "is not the root module."
            )


def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable:
    call = getattr(obj, api)
    if call in _patched_state_dict:
        call = functools.partial(getattr(obj.__class__, api), self=obj)
    return call


def _maybe_full_or_cpu_state_dict(
    state_dict: Dict[str, Any], info: _StateDictInfo
) -> Dict[str, Any]:
    if info.full_state_dict:
        ranks_only = (
            ()
            if (not info.cpu_offload or not torch.distributed.is_initialized())
            else (0,)
        )
        return _gather_state_dict(
            state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
        )
    elif info.cpu_offload:
        return _offload_state_dict_to_cpu(state_dict)
    else:
        return state_dict


@torch.no_grad()
def _get_model_state_dict(
    model: nn.Module, info: _StateDictInfo
) -> Dict[str, ValueType]:
    if not info.handle_model:
        return {}

    with info.fsdp_context():
        state_dict = _state_dict_fn(model, "state_dict")()

    for key in list(state_dict.keys()):
        fqns = _get_fqns(model, key)
        assert len(fqns) == 1, (key, fqns)
        fqn = next(iter(fqns))
        if fqn != key:
            # As we only support FSDP, DDP, and TP, the only cases are
            # wrapper-based DDP and compiler. Verify if the assumption
            # is correct.
            def verify(key, fqn) -> bool:
                if len(fqn) >= len(key):
                    return False
                fqn_split = fqn.split(".")
                key_split = key.split(".")
                fqn_idx = 0
                for key_idx, key_name in enumerate(key_split):
                    if key_name == fqn_split[fqn_idx]:
                        fqn_idx += 1
                        if fqn_idx == len(fqn_split):
                            return key_idx == len(key_split) - 1
                    elif key_name in ("module", "_orig_mod"):
                        continue
                    else:
                        return False
                return True

            if not verify(key, fqn):
                raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}")
            state_dict[fqn] = state_dict.pop(key)

    if info.submodule_prefixes:
        new_state_dict: Dict[str, ValueType] = {}
        # TODO: make this faster.
        for fqn in state_dict.keys():
            for prefix in info.submodule_prefixes:
                if not fqn.startswith(prefix):
                    continue
                if info.keep_submodule_prefixes:
                    new_state_dict[fqn] = state_dict[fqn]
                else:
                    new_fqn = fqn[len(prefix) :]
                    new_state_dict[new_fqn] = state_dict[fqn]
        state_dict = new_state_dict

    if info.ignore_frozen_params:
        for key, param in model.named_parameters():
            if param.requires_grad:
                continue
            fqns = _get_fqns(model, key)
            for fqn in fqns:
                state_dict.pop(fqn)

    for key, p in list(state_dict.items()):
        if torch.is_tensor(p) and p.is_meta:
            state_dict.pop(key)

    return _maybe_full_or_cpu_state_dict(state_dict, info)
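

# Hedged illustration of the ``verify`` check above (keys are made up): a
# wrapper-produced key may only differ from its canonical FQN by dropped
# "module" / "_orig_mod" components; any other mismatch raises.
#
#     # key = "module._orig_mod.layer1.weight", fqn = "layer1.weight"  -> accepted
#     # key = "backbone.layer1.weight",         fqn = "layer1.weight"  -> RuntimeError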


@torch.no_grad()
def _load_model_state_dict(
    model: nn.Module,
    state_dict: Dict[str, ValueType],
    info: _StateDictInfo,
) -> _IncompatibleKeys:
    if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):
        return _IncompatibleKeys({}, {})

    local_state_dict = {}
    for key, value in _iterate_valid_model_state(model):
        fqns = _get_fqns(model, key)
        fqns_with_prefix = _get_fqns(
            model, key, skip_ddp_prefix=False, skip_compiler_prefix=False
        )

        for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix):
            if (
                not info.broadcast_from_rank0 or dist.get_rank() == 0
            ) and fqn != fqn_with_prefix:
                state_dict[fqn_with_prefix] = state_dict.pop(fqn)
            local_state_dict[fqn_with_prefix] = value

    assign = False
    if info.broadcast_from_rank0 or info.full_state_dict:
        device = None
        for key, value in local_state_dict.items():
            if torch.is_tensor(value) and value.dim() > 0:
                if device is None:
                    device = value.device
                else:
                    assert device == value.device
        assert device is not None
        if device == torch.device("meta"):
            device = dist.distributed_c10d._get_pg_default_device()
            assign = True
        if info.broadcast_from_rank0:
            _broadcast_state_dict(
                state_dict, local_state_dict, device=device, strict=info.strict
            )
        elif info.full_state_dict:
            _distribute_state_dict(state_dict, local_state_dict, device=device)
        for fqn, local_state in local_state_dict.items():
            state_dict[fqn] = local_state

    with info.fsdp_context():
        return cast(
            _IncompatibleKeys,
            _state_dict_fn(model, "load_state_dict")(
                state_dict=state_dict, strict=info.strict, assign=assign
            ),
        )


def _init_optim_state(optim: torch.optim.Optimizer) -> None:
    """
    Initialize optim states by calling the step() with zero grads.
    """
    if optim.state:
        # The optimizer state is initialized.
        return

    # There are some stateless optimizers like SGD. These optimizers will not
    # return in the above condition. So if gradients exist, we should also
    # return. If gradients do not exist, the following initialization should
    # not disturb SGD because the gradients and lr are both zero.
    for param_group in optim.param_groups:
        for param in param_group[_PARAMS]:
            if param.grad is not None:
                return

    for param_group in optim.param_groups:
        for param in param_group[_PARAMS]:
            if param.requires_grad:
                param.grad = torch.zeros_like(param)

    # Some optimizers will update parameters regardless of grads due to lr, so
    # set lr to zero when calling `step()`.
    lrs = []
    for param_group in optim.param_groups:
        if "lr" in param_group:
            lrs.append(param_group["lr"])
            param_group["lr"] = 0.0
    optim.step(closure=None)
    # Whether to recover the "lr" should not matter too much as we will
    # restore the values from the checkpoint later.
    for param_group in optim.param_groups:
        if "lr" in param_group:
            param_group["lr"] = lrs.pop(0)
    optim.zero_grad(set_to_none=True)
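

# Hedged sketch of why ``_init_optim_state`` is needed (names are placeholders).
# Optimizers such as Adam allocate their per-parameter state lazily on the first
# ``step()``; loading a checkpoint into an optimizer that has never stepped
# would otherwise find no state entries to fill. The helper fakes one zero-grad,
# zero-lr step so the state tensors exist without changing the parameters:
#
#     # optim = torch.optim.Adam(model.parameters(), lr=1e-3)
#     # assert not optim.state              # no "step"/"exp_avg" allocated yet
#     # _init_optim_state(optim)
#     # assert optim.state                  # state exists, parameters unchanged
#     # set_optimizer_state_dict(model, optim, optim_state_dict=saved_osd)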


def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]:
    """
    This API flattens the optimizer state_dict to support optimizer resharding for
    MPMD, e.g., pipeline parallelism.

    Without the API, the original optimizer state_dict looks like:

    {
        "state": {
            "layer1.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
            "layer2.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
        },
        "param_groups": [
            {
                "lr": 0.1,
                "betas": (0.9, 0.95), ...,
                "params": ["layer1.weight", "layer2.weight"]
            }
        ]
    }

    With this API, the optimizer state_dict looks like:

    {
        "state.layer1.weight.step": 10,
        "state.layer2.weight.step": 10,
        "state.layer1.weight.exp_avg": SomeTensor,
        "state.layer2.weight.exp_avg": SomeTensor,
        "state.layer1.weight.exp_avg_sq": SomeTensor,
        "state.layer2.weight.exp_avg_sq": SomeTensor,
        "param_groups.layer1.weight.lr": 0.1,
        "param_groups.layer2.weight.lr": 0.1,
        "param_groups.layer1.weight.betas": (0.9, 0.95),
        "param_groups.layer2.weight.betas": (0.9, 0.95),
    }

    Note that if any of the values is a container, like the betas in the example,
    this API won't flatten it.
    """

    def _raise_if_type_not_supported(v):
        if not isinstance(v, (torch.Tensor, int, float)):
            raise NotImplementedError(
                "Flattening optimizer state_dict only supports "
                "tensor, int, float states now. "
                f"Type is {type(v)}."
            )

    ret: Dict[str, ValueType] = {}
    for fqn, state in cast(DictValueType, state_dict[_STATE]).items():
        for k, v in cast(DictValueType, state).items():
            _raise_if_type_not_supported(v)
            ret[f"{_STATE}.{fqn}.{k}"] = v

    for param_group in cast(ListDictValueType, state_dict[_PG]):
        fqns = param_group.pop(_PARAMS)
        for fqn in cast(List[str], fqns):
            for k, v in param_group.items():
                ret[f"{_PG}.{fqn}.{k}"] = v
    return ret


def _unflatten_optim_state_dict(
    optim: torch.optim.Optimizer,
    state_dict: Dict[str, ValueType],
    info: _StateDictInfo,
) -> OptimizerStateType:
    """
    This API unflattens the state_dict generated by _flatten_optim_state_dict().
    See the docstring of _flatten_optim_state_dict() for more detail.
    """
    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}

    for param_group in optim.param_groups:
        pg_state.append({_PARAMS: []})
        for param in param_group[_PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                params = pg_state[-1][_PARAMS]
                assert isinstance(params, list)  # typing
                params.append(fqn)
                if not param.requires_grad:
                    continue
                state[fqn] = {}
                for state_name in optim.state[param].keys():
                    cast(DictValueType, state[fqn])[state_name] = state_dict[
                        f"{_STATE}.{fqn}.{state_name}"
                    ]

        first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0]
        for k in param_group.keys():
            if k == _PARAMS:
                continue
            value = state_dict[f"{_PG}.{first_param_fqn}.{k}"]
            if k not in pg_state[-1]:
                pg_state[-1][k] = value
            elif pg_state[-1][k] != value:
                raise RuntimeError(
                    "All the parameters in the same parameter group should have "
                    f"the same saved param_group value. But {first_param_fqn}.{k} "
                    f"is {value} while other(s) is {pg_state[-1][k]}."
                )

    return return_osd


@torch.no_grad()
def _get_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    info: _StateDictInfo,
) -> OptimizerStateType:
    if not info.handle_optim:
        return {}

    optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []}
    for optim in optimizers:
        _init_optim_state(optim)
        osd = _state_dict_fn(optim, "state_dict")()
        if info.fsdp_modules:
            with info.fsdp_context():
                osd = FSDP.optim_state_dict(model, optim, osd)

            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            # There are no easy ways to do this conversion systematically.
            # We can only use a string replacement without a correctness check.
            if not osd:
                continue
            for k in list(osd[_STATE].keys()):
                if "_orig_mod" in k:
                    osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k)
            for g in osd[_PG]:
                params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]]
                g[_PARAMS] = params
        else:
            params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups))
            param_pid_mapping = dict(zip(params, range(len(params))))
            fqn_pid_mapping = {}
            for key, param in model.named_parameters():
                fqns = _get_fqns(model, key)
                assert len(fqns) == 1
                fqn = next(iter(fqns))
                if param not in param_pid_mapping:
                    continue
                pid = param_pid_mapping[param]
                fqn_pid_mapping[fqn] = pid
                fqn_pid_mapping[pid] = fqn

            for key in list(osd[_STATE].keys()):
                fqn = fqn_pid_mapping[key]
                osd[_STATE][fqn] = osd[_STATE].pop(key)

            for group in osd[_PG]:
                group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]]

        if not osd:
            continue

        cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE])
        cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG])

    if info.flatten_optimizer_state_dict:
        optim_state_dict = cast(
            OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)
        )

    return _maybe_full_or_cpu_state_dict(optim_state_dict, info)


def _split_optim_state_dict(
    model: nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> OptimizerStateType:
    """
    Extract the corresponding optim state_dict from ``optim_state_dict`` for
    ``optim`` and return the resulting optim state_dict.

    Args:
        model (nn.Module): the root model.
        optim (torch.optim.Optimizer): the optimizer.
        optim_state_dict (OptimizerStateType): the superset optim state_dict that
            contains the optim state_dict of ``optim``.
        info (_StateDictInfo): state dict information.

    Returns:
        The optim state_dict of ``optim``.
    """

    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}
    pg_mapping: Dict[int, int] = {}

    if all(
        isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()
    ):
        return optim_state_dict

    for param_group in optim.param_groups:
        pg_state.append({_PARAMS: []})
        for param in param_group[_PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                if fqn in info.shared_params_mapping:
                    in_params = False
                    for loaded_param_group in cast(
                        ListDictValueType, optim_state_dict[_PG]
                    ):
                        if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                            in_params = True
                            break
                else:
                    in_params = True
                if not in_params:
                    continue

                params = pg_state[-1][_PARAMS]
                assert isinstance(params, list)
                params.append(fqn)
                if param.requires_grad:
                    state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
                for loaded_param_group in cast(
                    ListDictValueType, optim_state_dict[_PG]
                ):
                    if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                        pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1

    for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
        idx = pg_mapping.get(id(param_group), -1)
        if idx == -1:
            continue
        for key, value in param_group.items():
            if key == _PARAMS:
                continue
            # TODO: check if value is the same if exists.
            pg_state[idx][key] = value

    return return_osd


@torch.no_grad()
def _load_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    if not info.handle_optim:
        return

    for optim in optimizers:
        _init_optim_state(optim)
        if state_dict:
            if _STATE in state_dict:
                optim_state_dict = _split_optim_state_dict(
                    model, optim, state_dict, info
                )
            else:
                optim_state_dict = _unflatten_optim_state_dict(
                    optim, cast(Dict[str, ValueType], state_dict), info
                )
        else:
            optim_state_dict = {}
        if info.fsdp_modules:
            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            for original_fqn, _ in model.named_parameters():
                fqns = _get_fqns(model, original_fqn)
                fqns_with_compiler = _get_fqns(
                    model, original_fqn, skip_compiler_prefix=False
                )
                if fqns == fqns_with_compiler:
                    continue

                assert len(fqns) == 1
                fqn = fqns.pop()
                fqn_with_compiler = fqns_with_compiler.pop()
                for g in optim_state_dict[_PG]:
                    val = cast(Dict[str, Any], g)
                    params = [
                        key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]
                    ]
                    val[_PARAMS] = params
                osd_state = cast(DictValueType, optim_state_dict[_STATE])
                for k in list(osd_state.keys()):
                    if fqn in k:
                        osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k)

            with info.fsdp_context():
                optim_state_dict = FSDP.optim_state_dict_to_load(
                    model, optim, optim_state_dict
                )
        elif info.full_state_dict:
            info.full_state_dict = False
            local_state_dict = _get_optim_state_dict(model, (optim,), info)
            info.full_state_dict = True
            device = None

            def _device(t):
                if t.dim() > 0:
                    nonlocal device
                    if device is None:
                        device = t.device
                    elif device != t.device:
                        raise ValueError("Device mismatch")
                return t

            _ = tree_map_only(torch.Tensor, _device, local_state_dict)
            assert device is not None
            flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)
            flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)
            if info.broadcast_from_rank0:
                _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device)
            else:
                _distribute_state_dict(flatten_osd, flatten_local_osd, device=device)
            # The merging below addresses the case where ``optim`` and
            # ``optim_state_dict`` do not cover exactly the same parameters:
            # entries that only exist in the loaded optim_state_dict are copied
            # into the local state_dict, so ``optim`` may end up with additional
            # parameters after loading.
            for optim_key in flatten_osd.keys():
                if optim_key not in flatten_local_osd:
                    assert optim_key in osd_mapping
                    flatten_local_osd[optim_key] = flatten_osd[optim_key]
                    local_osd_mapping[optim_key] = osd_mapping[optim_key]
            optim_state_dict = _unflatten_state_dict(
                flatten_local_osd, local_osd_mapping
            )

        # Note that we do not have to convert the FQN back to param id here if
        # the order in optim.param_groups[idx][_PARAMS] is the same as the one
        # in optim_state_dict[_PG][idx][_PARAMS].
        _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict)


def get_model_state_dict(
    model: nn.Module,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> Dict[str, ValueType]:
    """
    Return the model state_dict of ``model``.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``model``.

    :rtype: typing.Dict[str, ValueType]
    """
    with _gc_context():
        info = _verify_options(
            model,
            (),
            optim_only=False,
            submodules=submodules,
            options=options,
        )
        model_state_dict = _get_model_state_dict(model, info)
        _verify_state_dict(model_state_dict, {}, info)
        return model_state_dict
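

# Hedged example for ``get_model_state_dict`` (placeholder model and wrapping,
# not a doctest of this module): the returned keys are canonical FQNs, so a
# wrapped model produces the same key set as the plain module.
#
#     # plain_sd = MyModel().state_dict()
#     # fsdp_sd = get_model_state_dict(FSDP(MyModel()))
#     # assert set(plain_sd.keys()) == set(fsdp_sd.keys())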


def get_optimizer_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> OptimizerStateType:
    """
    Return the combined state_dict for optimizers.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``optimizers``.

    :rtype: OptimizerStateType
    """
    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(
            model,
            optimizers,
            optim_only=True,
            submodules=submodules,
            options=options,
        )
        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
        _verify_state_dict({}, optim_state_dict, info)
        return optim_state_dict


def get_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> Tuple[Dict[str, ValueType], OptimizerStateType]:
    """
    Return the model state_dict and optimizers state_dict.

    ``get_state_dict`` can process any module that is parallelized by PyTorch
    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any
    combination of these parallelisms. The main functions of ``get_state_dict``
    are: 1.) returning a model and optimizer state_dict that can be resharded
    with a different number of trainers and/or different parallelisms,
    2.) hiding the parallelism-specific state_dict APIs so users don't have to
    call them directly, and 3.) sanity checking the result state_dict.

    The keys of the result state dictionary are the canonical FQNs (Fully
    Qualified Names). A canonical FQN refers to the FQN based on a parameter's
    position in an nn.Module hierarchy. More specifically, a canonical FQN to a
    parameter is the FQN returned by ``module.named_parameters()`` or
    ``module.named_buffers()`` when the module is not distributed by any
    parallelisms. Since the optimizer internally uses parameter IDs to represent
    a parameter, there will be a conversion from the parameter IDs to the
    canonical FQNs when calling this API.

    ``get_state_dict`` can also process a module that is not parallelized. In
    such a case, ``get_state_dict`` only performs one function -- converting the
    optimizer parameter IDs to the canonical FQNs.

    Example:
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.checkpoint.state_dict import get_state_dict

        >>> fsdp_model = FSDP(copy.deepcopy(model))
        >>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
        >>> ddp_model = DDP(copy.deepcopy(model))
        >>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)


        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)

        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
        >>> # the asserts will fail.
        >>> assert ddp_state_dict == fsdp_state_dict
        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict


    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        ``Tuple`` that contains the model state_dict and the optimizer state_dict.

    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
    """

    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(
            model,
            optimizers,
            optim_only=False,
            submodules=submodules,
            options=options,
        )
        model_state_dict = _get_model_state_dict(model, info)
        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
        _verify_state_dict(model_state_dict, optim_state_dict, info)
        return model_state_dict, optim_state_dict
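

# Hedged sketch tying the getters above to ``StateDictOptions`` (``model`` and
# ``optim`` are placeholders). With ``flatten_optimizer_state_dict=True`` the
# optimizer state comes back as flat "state.<fqn>.<name>" /
# "param_groups.<fqn>.<name>" keys (see ``_flatten_optim_state_dict``), and the
# same flattened layout is accepted again on the load path:
#
#     # options = StateDictOptions(flatten_optimizer_state_dict=True)
#     # _, flat_osd = get_state_dict(model, optim, options=options)
#     # # e.g. flat_osd["state.layer1.weight.exp_avg"] holds this rank's shard
#     # set_optimizer_state_dict(model, optim, optim_state_dict=flat_osd, options=options)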


def _unflatten_model_state_dict(
    model: nn.Module,
    state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
) -> Dict[str, ValueType]:
    if not state_dict:
        return {}

    if isinstance(next(iter(state_dict.keys())), nn.Module):
        warnings.warn(
            "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]`` "
            "is deprecated and will be removed in 2.5. If you need this "
            "feature, please preprocess the model_state_dict to achieve the "
            "same functionality.",
            FutureWarning,
        )
        cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
        new_state_dict: Dict[str, ValueType] = {}
        for submodule, sub_state_dict in cast_state_dict.items():
            for name, m in model.named_modules():
                if m != submodule:
                    continue

                fqns = _get_fqns(model, name)
                assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
                prefix = f"{next(iter(fqns))}."
                new_state_dict.update(
                    {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
                )
        return new_state_dict
    else:
        return cast(Dict[str, ValueType], state_dict)


def set_model_state_dict(
    model: nn.Module,
    model_state_dict: Dict[str, ValueType],
    *,
    options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
    """Load the model state_dict.

    The counterpart of ``get_model_state_dict`` to set the state_dict to the
    model. See ``set_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        model_state_dict: (Dict[str, ValueType]):
            the model state_dict to load. If the key of the ``model_state_dict``
            is nn.Module, the key is a submodule of ``model`` and the value should
            be the state_dict of the submodule. When loading the state_dict,
            the prefix of the submodule will be prepended to the keys of its
            state_dict.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    :type model_state_dict: typing.Dict[str, ValueType]
    """
    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
        model, model_state_dict
    )
    with _gc_context():
        info = _verify_options(model, (), optim_only=False, options=options)

        _verify_state_dict(model_state_dict, {}, info)
        return _load_model_state_dict(model, model_state_dict, info)
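

# Hedged example for ``set_model_state_dict`` (paths and names are placeholders):
# a plain, CPU full checkpoint can be loaded into a sharded model by flagging it
# as a full state_dict; ``strict=False`` tolerates missing/unexpected keys.
#
#     # sd = torch.load("checkpoint.pt", map_location="cpu")   # hypothetical file
#     # set_model_state_dict(
#     #     fsdp_model,
#     #     sd,
#     #     options=StateDictOptions(full_state_dict=True, strict=False),
#     # )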


def set_optimizer_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    optim_state_dict: OptimizerStateType,
    *,
    options: Optional[StateDictOptions] = None,
) -> None:
    """Load the optimizers state_dict.

    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
    optimizers. See ``set_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        None

    :type optim_state_dict: typing.OptimizerStateType
    """
    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(model, optimizers, optim_only=True, options=options)

        _verify_state_dict({}, optim_state_dict, info)
        _load_optim_state_dict(model, optimizers, optim_state_dict, info)


def set_state_dict(
    model: nn.Module,
    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
    *,
    model_state_dict: Dict[str, ValueType],
    optim_state_dict: OptimizerStateType,
    options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
    """Load the model state_dict and optimizers state_dict.

    The counterpart of ``get_state_dict`` to set the state_dict to the model and
    optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not
    have to be returned by ``get_state_dict`` but must meet the following
    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,
    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,
    3) optimizer state_dict cannot contain the parameter IDs; the keys should be
    the canonical FQNs.

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
            the model state_dict to load. If the key of the ``model_state_dict``
            is nn.Module, the key is a submodule of ``model`` and the value should
            be the state_dict of the submodule. When loading the state_dict,
            the prefix of the submodule will be prepended to the keys of its
            state_dict.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys of the model state_dict.
            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.

    :type model_state_dict: typing.Dict[str, ValueType]
    :type optim_state_dict: typing.OptimizerStateType
    """

    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
        model, model_state_dict
    )
    with _gc_context():
        optimizers = (
            (optimizers,)
            if isinstance(optimizers, torch.optim.Optimizer)
            else tuple(optimizers)
        )
        info = _verify_options(
            model, optimizers, optim_only=not model_state_dict, options=options
        )

        _verify_state_dict(model_state_dict, optim_state_dict, info)
        _load_optim_state_dict(model, optimizers, optim_state_dict, info)
        return _load_model_state_dict(model, model_state_dict, info)
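

# Hedged end-to-end sketch (process-group setup, ``model``/``optim`` and the
# storage step are placeholders): save with ``get_state_dict`` and restore with
# ``set_state_dict``; because both sides use canonical FQNs, the pair still
# works if the wrapping or the world size changed in between.
#
#     # model_sd, optim_sd = get_state_dict(model, optim)
#     # ... persist model_sd/optim_sd, e.g. via torch.distributed.checkpoint ...
#     # set_state_dict(
#     #     new_model,
#     #     new_optim,
#     #     model_state_dict=model_sd,
#     #     optim_state_dict=optim_sd,
#     # )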


# TODO: correct the state_dict function signature.
# TODO: this API is not yet fully tested. Make it private
@no_type_check
def _patch_model_state_dict(
    model: nn.Module,
    *,
    options: Optional[StateDictOptions] = None,
) -> None:
    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.

    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to
    be partial functions that call ``get_model_state_dict`` and
    ``set_model_state_dict``.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_model_state_dict

        model = FSDP(model)
        _patch_model_state_dict(model)

    Args:
        model (nn.Module): the nn.Module for the model.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """

    _state_dict_call = functools.partial(
        get_model_state_dict,
        model=model,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    model.state_dict = state_dict_call

    _load_state_dict_call = functools.partial(
        set_model_state_dict,
        model=model,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        _load_state_dict_call(model_state_dict=state_dict)

    model.load_state_dict = load_state_dict_call

    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)


# TODO: correct the load_state_dict function signature.
# TODO: this API is not yet fully tested. Make it private
@no_type_check
def _patch_optimizer_state_dict(
    model: nn.Module,
    *,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    options: Optional[StateDictOptions] = None,
) -> None:
    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.

    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to
    be partial functions that call ``get_optimizer_state_dict`` and
    ``set_optimizer_state_dict``.

    Note that if there are multiple optimizers, all of the optimizers will be patched.
    So users only need to call one of the state_dict() to get the full result.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_optimizer_state_dict

        model = FSDP(model)
        _patch_optimizer_state_dict(model, optimizers=(optim,))

    Args:
        model (nn.Module): the nn.Module for the model.
        optimizers (Tuple[Optimizer, ...]): the optimizers that are used to
            optimize ``model``.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """

    _state_dict_call = functools.partial(
        get_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    _load_state_dict_call = functools.partial(
        set_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        _load_state_dict_call(optim_state_dict=state_dict)

    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)
    optimizers = (
        (optimizers,)
        if isinstance(optimizers, torch.optim.Optimizer)
        else tuple(optimizers)
    )
    for optim in optimizers:
        optim.state_dict = state_dict_call
        optim.load_state_dict = load_state_dict_call