# mypy: allow-untyped-defs
import contextlib
import functools
import gc
import warnings
from dataclasses import asdict, dataclass, field
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import (
    _broadcast_state_dict,
    _distribute_state_dict,
    _flatten_state_dict,
    _gather_state_dict,
    _offload_state_dict_to_cpu,
    _unflatten_state_dict,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state_if_fully_sharded_module,
    FSDP_WRAPPED_MODULE,
)
from torch.distributed.tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils._pytree import tree_map_only


__all__ = [
    "FQNS_T",
    "PrimitiveType",
    "ValueType",
    "DictValueType",
    "ListDictValueType",
    "OptimizerStateType",
    "StateDictOptions",
    "get_model_state_dict",
    "get_optimizer_state_dict",
    "get_state_dict",
    "set_model_state_dict",
    "set_optimizer_state_dict",
    "set_state_dict",
]


_FLAT_PARAM = "_flat_param"
_PG = "param_groups"
_PARAMS = "params"
_STATE = "state"

FQNS_T = Set[str]
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
ValueType = Union[
    PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"]
]
DictValueType = Dict[str, ValueType]
ListDictValueType = List[DictValueType]
OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]


_patched_state_dict: Set[Callable] = set()


@contextlib.contextmanager
def _gc_context():
    is_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if is_enabled:
            gc.enable()


@dataclass
class StateDictOptions:
    """
    This dataclass specifies how get_state_dict/set_state_dict will work.

    - ``full_state_dict``: if this is set to True, all the tensors in the
      returned state_dict will be gathered. No ShardedTensor and DTensor
      will be in the returned state_dict.

    - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
      ``full_state_dict`` is also True, then only rank0 will get the
      state_dict and all other ranks will get empty state_dicts.

    - ``ignore_frozen_params``: if the value is True, the returned state_dict
      won't contain any frozen parameters -- parameters whose ``requires_grad``
      is False. The default value is False.

    - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option
      indicates whether to keep the submodule prefixes in the state_dict keys.
      For example, suppose the submodule is ``module.pretrain`` and the full FQN
      of the parameter is ``pretrain.layer1.weight``. When this option is True,
      the parameter's key in the returned state_dict will be
      ``pretrain.layer1.weight``. If the option is False, the key will be
      ``layer1.weight``.
      Note that if ``keep_submodule_prefixes`` is False, there may be conflicting
      FQNs, hence there should be only one submodule in ``submodules``.

    - ``strict``: the ``strict`` option when ``set_state_dict`` calls
      model.load_state_dict().

    - ``broadcast_from_rank0``: when the option is True, rank0 should receive a
      full state_dict and will broadcast the tensors in the state_dict/
      optim_state_dict one by one to the other ranks. Other ranks will receive
      the tensors and shard them according to the local shards in the model and
      optimizer. ``full_state_dict`` must be set to True when using this option.
      This option currently only supports DTensor, not the legacy ShardedTensor.
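
    Example (a minimal sketch rather than a full recipe; it assumes ``model``
    is an already-wrapped distributed module, e.g. an FSDP instance):

        >>> # xdoctest: +SKIP
        >>> from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
        >>> # Gather a full state_dict offloaded to CPU; with both options set,
        >>> # only rank0 keeps the result and other ranks get an empty dict.
        >>> options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        >>> state_dict = get_model_state_dict(model, options=options)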
140    """
141
142    full_state_dict: bool = False
143    cpu_offload: bool = False
144    ignore_frozen_params: bool = False
145    keep_submodule_prefixes: bool = True
146    strict: bool = True
147    broadcast_from_rank0: bool = False
148    flatten_optimizer_state_dict: bool = False
149
150
151@dataclass
152class _StateDictInfo(StateDictOptions):
153    fqn_param_mapping: Dict[
154        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
155    ] = field(default_factory=dict)
156    shared_params_mapping: Dict[
157        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
158    ] = field(default_factory=dict)
159    submodule_prefixes: Set[str] = field(default_factory=set)
160    handle_model: bool = True
161    handle_optim: bool = True
162    fsdp_context: Callable = contextlib.nullcontext
163    fsdp_modules: List[nn.Module] = field(default_factory=list)
164
165
166@functools.lru_cache(maxsize=None)
167def _get_fqns(
168    model: nn.Module,
169    name: str,
170    skip_ddp_prefix: bool = True,
171    skip_compiler_prefix: bool = True,
172) -> FQNS_T:
173    """
174    This API is used to convert the name of a parameter to the FQNs. For FSDP
175    without `use_orig_params`, the name of FlatParameter can be mapped to
176    multiple original parameters. As a result, the return type of this function
177    is `Set[str]`.
178
179    Args:
        model (nn.Module): the root model.
        name (str): the name of the parameter or buffer.
        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix.
        skip_compiler_prefix (bool): whether to skip the compiler's `_orig_mod` prefix.

    Returns:
        The canonical FQNs based on the model traversal.
    """

    # Remove the checkpoint prefix, if it exists.
    name = name.replace(_CHECKPOINT_PREFIX, "")
    if "." not in name:
        return {name}

    obj_names = name.split(".")
    fqn_obj_names = []
    curr_obj = model
    for i, curr_obj_name in enumerate(obj_names):
        if isinstance(curr_obj, DDP):
            assert curr_obj_name == "module"
            curr_obj = curr_obj.module
            if not skip_ddp_prefix:
                fqn_obj_names.append(curr_obj_name)
        elif isinstance(curr_obj, FSDP):
            if i < len(obj_names) - 1 and obj_names[i + 1] == _FLAT_PARAM:
                prefix = ".".join(fqn_obj_names)
                flat_param = getattr(curr_obj, _FLAT_PARAM)
                if prefix:
                    prefix = f"{prefix}."
                return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
            curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
            if curr_obj_name != FSDP_WRAPPED_MODULE:
                fqn_obj_names.append(curr_obj_name)
                curr_obj = getattr(curr_obj, curr_obj_name)
        elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
            assert curr_obj_name == "_orig_mod"
            curr_obj = curr_obj._orig_mod
            if not skip_compiler_prefix:
                fqn_obj_names.append(curr_obj_name)
        else:
            fqn_obj_names.append(curr_obj_name)
            if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
                if i != len(obj_names) - 1:
                    raise RuntimeError("Expect `_extra_state` to be the last obj name")
            else:
                curr_obj = getattr(curr_obj, curr_obj_name)

    return {".".join(fqn_obj_names).replace(_CHECKPOINT_PREFIX, "")}


class _EXTRA_STATE:
    pass


def _iterate_valid_model_state(model):
    visited_modules: Set[nn.Module] = set()

    def recurse(module: nn.Module, curr_fqn: str) -> Generator:
        visited_modules.add(module)

        curr_fqn = f"{curr_fqn}." if curr_fqn else ""
        for name, submodule in module.named_children():
            if submodule in visited_modules:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield from recurse(submodule, new_fqn)

        for name, obj in chain(
            module.named_buffers(recurse=False), module.named_parameters(recurse=False)
        ):
            if name in module._non_persistent_buffers_set:
                continue
            new_fqn = f"{curr_fqn}{name}"
            yield new_fqn, obj

        if (
            getattr(module.__class__, "get_extra_state", nn.Module.get_extra_state)
            != nn.Module.get_extra_state
        ):
            new_fqn = f"{curr_fqn}{nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
            yield new_fqn, _EXTRA_STATE()

    yield from recurse(model, "")


def _verify_options(
    model: nn.Module,
    optims: Tuple[torch.optim.Optimizer, ...],
    optim_only: bool,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> _StateDictInfo:
272    """
273    Verify the model and options passed by the user and generates _StateDictInfo.
274    """
    if submodules:
        warnings.warn(
            "Getting the model/optim state_dict for only some submodules is "
            "deprecated and will be removed in 2.5. This feature can be achieved "
            "by manually filtering the state_dict returned from get_state_dict.",
            FutureWarning,
        )
    if optim_only and not optims:
        raise RuntimeError(
            "Optimizers are not passed in but optim_only is set to True."
        )

    options = options or StateDictOptions()

    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    shared_params_mapping: Dict[
        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
    ] = {}
    for name, param in _iterate_valid_model_state(model):
        if isinstance(param, _EXTRA_STATE):
            continue

        fqns = _get_fqns(model, name)
        fqn = fqn_param_mapping.get(param, None)
        if fqn is not None:
            cast(Set[str], fqn_param_mapping[param]).update(fqns)
            shared_params_mapping[param] = fqn_param_mapping[param]
        else:
            # We need to copy the set because _get_fqns is lru_cached.
            fqn_param_mapping[param] = fqns.copy()
        for fqn in fqns:
            if not isinstance(param, _EXTRA_STATE):
                fqn_param_mapping[fqn] = param

    for param_, fqns_ in list(shared_params_mapping.items()):
        for fqn in fqns_:
            shared_params_mapping[fqn] = cast(torch.Tensor, param_)

    submodule_prefixes: Set[str] = set()
    if submodules:
        submodules = set(submodules)
        for name, module in model.named_modules():
            if module not in submodules:
                continue
            fqns = _get_fqns(model, name)
            assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
            submodule_prefixes.update(f"{fqn}." for fqn in fqns)

    if options.broadcast_from_rank0 and not options.full_state_dict:
        raise ValueError(
            "full_state_dict must be True when broadcast_from_rank0 is True."
        )
    fsdp_modules = FSDP.fsdp_modules(model)
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig
    fsdp_context: Callable
    if fsdp_modules:
        # FSDP API only works if at least one FSDP instance exists.
        if options.full_state_dict:
            state_dict_config = FullStateDictConfig(
                offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload
            )
            optim_state_dict_config = FullOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload,
                rank0_only=(options.cpu_offload or options.broadcast_from_rank0),
            )
            state_dict_type = StateDictType.FULL_STATE_DICT
        else:
            state_dict_config = ShardedStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            optim_state_dict_config = ShardedOptimStateDictConfig(
                offload_to_cpu=options.cpu_offload,
            )
            state_dict_type = StateDictType.SHARDED_STATE_DICT

        @contextlib.contextmanager
        def fsdp_state_dict_type_without_warning(
            module,
            state_dict_type,
            state_dict_config,
            optim_state_dict_config,
        ):
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore", message="FSDP.state_dict_type", category=FutureWarning
                )
                with FSDP.state_dict_type(
                    module=module,
                    state_dict_type=state_dict_type,
                    state_dict_config=state_dict_config,
                    optim_state_dict_config=optim_state_dict_config,
                ):
                    yield

        fsdp_context = functools.partial(
            fsdp_state_dict_type_without_warning,
            module=model,
            state_dict_type=state_dict_type,
            state_dict_config=state_dict_config,
            optim_state_dict_config=optim_state_dict_config,
        )
    else:
        fsdp_context = contextlib.nullcontext

    return _StateDictInfo(
        **asdict(options),
        fqn_param_mapping=fqn_param_mapping,
        shared_params_mapping=shared_params_mapping,
        submodule_prefixes=submodule_prefixes,
        fsdp_context=fsdp_context,
        fsdp_modules=cast(List[nn.Module], fsdp_modules),
        handle_model=not optim_only,
        handle_optim=(len(optims) > 0),
    )


def _verify_state_dict(
    model_state_dict: Dict[str, ValueType],
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    for module in info.fsdp_modules:
        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
        assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."

    # Verify if the model_state_dict and optim_state_dict are valid. This API
    # should give the users an explicit error message to debug or report.
    if (
        info.handle_model
        and not model_state_dict
        and not info.submodule_prefixes
        and not info.ignore_frozen_params
        and not (info.cpu_offload and info.full_state_dict)
        and info.strict
        and not info.broadcast_from_rank0
    ):
        raise RuntimeError(
            "The option indicates that model state_dict is required to save "
            "or load, but model state_dict is empty. "
            f"rank = {dist.get_rank()}."
        )

    if info.handle_optim:
        if (
            not optim_state_dict
            and not (info.cpu_offload and info.full_state_dict)
            and (not info.broadcast_from_rank0)
        ):
            raise RuntimeError(
                "The option indicates that optim state_dict is required to save "
                f"or load, but optim state_dict is empty. {optim_state_dict}"
            )

    for key in model_state_dict.keys():
        if _FLAT_PARAM in key:
            raise RuntimeError(
                f"{key} contains {_FLAT_PARAM}. This can happen if the model "
                "is not the root module."
            )


def _state_dict_fn(obj: Union[nn.Module, torch.optim.Optimizer], api: str) -> Callable:
    call = getattr(obj, api)
    if call in _patched_state_dict:
        call = functools.partial(getattr(obj.__class__, api), self=obj)
    return call


def _maybe_full_or_cpu_state_dict(
    state_dict: Dict[str, Any], info: _StateDictInfo
) -> Dict[str, Any]:
    if info.full_state_dict:
        ranks_only = (
            ()
            if (not info.cpu_offload or not torch.distributed.is_initialized())
            else (0,)
        )
        return _gather_state_dict(
            state_dict, cpu_offload=info.cpu_offload, ranks_only=ranks_only
        )
    elif info.cpu_offload:
        return _offload_state_dict_to_cpu(state_dict)
    else:
        return state_dict


@torch.no_grad()
def _get_model_state_dict(
    model: nn.Module, info: _StateDictInfo
) -> Dict[str, ValueType]:
    if not info.handle_model:
        return {}

    with info.fsdp_context():
        state_dict = _state_dict_fn(model, "state_dict")()

    for key in list(state_dict.keys()):
        fqns = _get_fqns(model, key)
        assert len(fqns) == 1, (key, fqns)
        fqn = next(iter(fqns))
        if fqn != key:
            # As we only support FSDP, DDP, and TP, the only cases are
            # wrapper-based DDP and compiler. Verify if the assumption
            # is correct.
            def verify(key, fqn) -> bool:
                if len(fqn) >= len(key):
                    return False
                fqn_split = fqn.split(".")
                key_split = key.split(".")
                fqn_idx = 0
                for key_idx, key_name in enumerate(key_split):
                    if key_name == fqn_split[fqn_idx]:
                        fqn_idx += 1
                        if fqn_idx == len(fqn_split):
                            return key_idx == len(key_split) - 1
                    elif key_name in ("module", "_orig_mod"):
                        continue
                    else:
                        return False
                return True

            if not verify(key, fqn):
                raise RuntimeError(f"An unexpected key, {key}, exists. FQN is {fqn}")
            state_dict[fqn] = state_dict.pop(key)

    if info.submodule_prefixes:
        new_state_dict: Dict[str, ValueType] = {}
        # TODO: make this faster.
        for fqn in state_dict.keys():
            for prefix in info.submodule_prefixes:
                if not fqn.startswith(prefix):
                    continue
                if info.keep_submodule_prefixes:
                    new_state_dict[fqn] = state_dict[fqn]
                else:
                    new_fqn = fqn[len(prefix) :]
                    new_state_dict[new_fqn] = state_dict[fqn]
        state_dict = new_state_dict

    if info.ignore_frozen_params:
        for key, param in model.named_parameters():
            if param.requires_grad:
                continue
            fqns = _get_fqns(model, key)
            for fqn in fqns:
                state_dict.pop(fqn)

    for key, p in list(state_dict.items()):
        if torch.is_tensor(p) and p.is_meta:
            state_dict.pop(key)

    return _maybe_full_or_cpu_state_dict(state_dict, info)


@torch.no_grad()
def _load_model_state_dict(
    model: nn.Module,
    state_dict: Dict[str, ValueType],
    info: _StateDictInfo,
) -> _IncompatibleKeys:
    if not info.handle_model or (not state_dict and not info.broadcast_from_rank0):
        return _IncompatibleKeys({}, {})

    local_state_dict = {}
    for key, value in _iterate_valid_model_state(model):
        fqns = _get_fqns(model, key)
        fqns_with_prefix = _get_fqns(
            model, key, skip_ddp_prefix=False, skip_compiler_prefix=False
        )

        for fqn, fqn_with_prefix in zip(fqns, fqns_with_prefix):
            if (
                not info.broadcast_from_rank0 or dist.get_rank() == 0
            ) and fqn != fqn_with_prefix:
                state_dict[fqn_with_prefix] = state_dict.pop(fqn)
            local_state_dict[fqn_with_prefix] = value

    assign = False
    if info.broadcast_from_rank0 or info.full_state_dict:
        device = None
        for key, value in local_state_dict.items():
            if torch.is_tensor(value) and value.dim() > 0:
                if device is None:
                    device = value.device
                else:
                    assert device == value.device
        assert device is not None
        if device == torch.device("meta"):
            device = dist.distributed_c10d._get_pg_default_device()
            assign = True
        if info.broadcast_from_rank0:
            _broadcast_state_dict(
                state_dict, local_state_dict, device=device, strict=info.strict
            )
        elif info.full_state_dict:
            _distribute_state_dict(state_dict, local_state_dict, device=device)
        for fqn, local_state in local_state_dict.items():
            state_dict[fqn] = local_state

    with info.fsdp_context():
        return cast(
            _IncompatibleKeys,
            _state_dict_fn(model, "load_state_dict")(
                state_dict=state_dict, strict=info.strict, assign=assign
            ),
        )


def _init_optim_state(optim: torch.optim.Optimizer) -> None:
    """
    Initialize the optimizer state by calling ``step()`` with zero gradients.
    """
    if optim.state:
        # The optimizer state is already initialized.
        return

    # There are some stateless optimizers like SGD. These optimizers will not
    # be caught by the above condition. So if gradients exist, we should also
    # return early. If gradients do not exist, the following initialization
    # should not disturb SGD because the gradients and lr are both zero.
    for param_group in optim.param_groups:
        for param in param_group[_PARAMS]:
            if param.grad is not None:
                return

    for param_group in optim.param_groups:
        for param in param_group[_PARAMS]:
            if param.requires_grad:
                param.grad = torch.zeros_like(param)

    # Some optimizers will update parameters regardless of grads due to lr, so
    # set the lr to zero when calling `step()`.
    lrs = []
    for param_group in optim.param_groups:
        if "lr" in param_group:
            lrs.append(param_group["lr"])
            param_group["lr"] = 0.0
    optim.step(closure=None)
    # Whether to restore the original "lr" should not matter too much as we
    # will load the checkpointed values later anyway.
    for param_group in optim.param_groups:
        if "lr" in param_group:
            param_group["lr"] = lrs.pop(0)
    optim.zero_grad(set_to_none=True)


def _flatten_optim_state_dict(state_dict: OptimizerStateType) -> Dict[str, ValueType]:
    """
    This API flattens the optimizer state_dict to support optimizer resharding for
    MPMD, e.g., pipeline parallelism.

    Without the API, the original optimizer state_dict looks like:
    {
        "state": {
            "layer1.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
            "layer2.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
        },
        "param_groups": [
            {
                "lr": 0.1,
                "betas": (0.9, 0.95), ...,
                "params": ["layer1.weight", "layer2.weight"]
            }
        ]
    }

    With this API, the optimizer state_dict looks like:
    {
        "state.layer1.weight.step": 10,
        "state.layer2.weight.step": 10,
        "state.layer1.weight.exp_avg": SomeTensor,
        "state.layer2.weight.exp_avg": SomeTensor,
        "state.layer1.weight.exp_avg_sq": SomeTensor,
        "state.layer2.weight.exp_avg_sq": SomeTensor,
        "param_groups.layer1.weight.lr": 0.1,
        "param_groups.layer2.weight.lr": 0.1,
        "param_groups.layer1.weight.betas": (0.9, 0.95),
        "param_groups.layer2.weight.betas": (0.9, 0.95),
    }

    Note that if any of the values is a container, like the betas in the example,
    this API won't flatten it.
    """

    def _raise_if_type_not_supported(v):
        if not isinstance(v, (torch.Tensor, int, float)):
            raise NotImplementedError(
                "Flattening optimizer state_dict only supports "
                "tensor, int, float states now. "
                f"Type is {type(v)}."
            )

    ret: Dict[str, ValueType] = {}
    for fqn, state in cast(DictValueType, state_dict[_STATE]).items():
        for k, v in cast(DictValueType, state).items():
            _raise_if_type_not_supported(v)
            ret[f"{_STATE}.{fqn}.{k}"] = v

    for param_group in cast(ListDictValueType, state_dict[_PG]):
        fqns = param_group.pop(_PARAMS)
        for fqn in cast(List[str], fqns):
            for k, v in param_group.items():
                ret[f"{_PG}.{fqn}.{k}"] = v
    return ret


def _unflatten_optim_state_dict(
    optim: torch.optim.Optimizer,
    state_dict: Dict[str, ValueType],
    info: _StateDictInfo,
) -> OptimizerStateType:
    """
    This API unflattens the state_dict generated by _flatten_optim_state_dict().
    See the docstring of _flatten_optim_state_dict() for more detail.
    """
    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}

    for param_group in optim.param_groups:
        pg_state.append({_PARAMS: []})
        for param in param_group[_PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                params = pg_state[-1][_PARAMS]
                assert isinstance(params, list)  # typing
                params.append(fqn)
                if not param.requires_grad:
                    continue
                state[fqn] = {}
                for state_name in optim.state[param].keys():
                    cast(DictValueType, state[fqn])[state_name] = state_dict[
                        f"{_STATE}.{fqn}.{state_name}"
                    ]

        first_param_fqn = cast(List[str], pg_state[-1][_PARAMS])[0]
        for k in param_group.keys():
            if k == _PARAMS:
                continue
            value = state_dict[f"{_PG}.{first_param_fqn}.{k}"]
            if k not in pg_state[-1]:
                pg_state[-1][k] = value
            elif pg_state[-1][k] != value:
                raise RuntimeError(
                    "All the parameters in the same parameter group should have "
                    f"the same saved param_group value. But {first_param_fqn}.{k} "
                    f"is {value} while other(s) is {pg_state[-1][k]}."
                )

    return return_osd


@torch.no_grad()
def _get_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    info: _StateDictInfo,
) -> OptimizerStateType:
    if not info.handle_optim:
        return {}

    optim_state_dict: OptimizerStateType = {_STATE: {}, _PG: []}
    for optim in optimizers:
        _init_optim_state(optim)
        osd = _state_dict_fn(optim, "state_dict")()
        if info.fsdp_modules:
            with info.fsdp_context():
                osd = FSDP.optim_state_dict(model, optim, osd)

            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            # There is no easy way to do this conversion systematically.
            # We can only use a string replacement without a correctness check.
            if not osd:
                continue
            for k in list(osd[_STATE].keys()):
                if "_orig_mod" in k:
                    osd[_STATE][k.replace("_orig_mod.", "")] = osd[_STATE].pop(k)
            for g in osd[_PG]:
                params = [k.replace("_orig_mod.", "") for k in g[_PARAMS]]
                g[_PARAMS] = params
        else:
            params = list(chain.from_iterable(g[_PARAMS] for g in optim.param_groups))
            param_pid_mapping = dict(zip(params, range(len(params))))
            fqn_pid_mapping = {}
            for key, param in model.named_parameters():
                fqns = _get_fqns(model, key)
                assert len(fqns) == 1
                fqn = next(iter(fqns))
                if param not in param_pid_mapping:
                    continue
                pid = param_pid_mapping[param]
                fqn_pid_mapping[fqn] = pid
                fqn_pid_mapping[pid] = fqn

            for key in list(osd[_STATE].keys()):
                fqn = fqn_pid_mapping[key]
                osd[_STATE][fqn] = osd[_STATE].pop(key)

            for group in osd[_PG]:
                group[_PARAMS] = [fqn_pid_mapping[pid] for pid in group[_PARAMS]]

        if not osd:
            continue

        cast(DictValueType, optim_state_dict[_STATE]).update(osd[_STATE])
        cast(ListDictValueType, optim_state_dict[_PG]).extend(osd[_PG])

    if info.flatten_optimizer_state_dict:
        optim_state_dict = cast(
            OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)
        )

    return _maybe_full_or_cpu_state_dict(optim_state_dict, info)


def _split_optim_state_dict(
    model: nn.Module,
    optim: torch.optim.Optimizer,
    optim_state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> OptimizerStateType:
803    """
804    Extract the corresponding optim state_dict from ``optim_state_dict`` for
805    ``optim`` and return the result optim state_dict.
806
807    Args:
808        model (nn.Module): the root model.
809        optim (torch.optim.Optimizer): the optimizer.
810        optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that
811            contains the optim state_dict of ``optim``.
812        info (_StateDictInfo): state dict information.
813
814    Returns:
815        The optim state_dict of ``optim``.
816    """

    state: DictValueType = {}
    pg_state: ListDictValueType = []
    return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state}
    pg_mapping: Dict[int, int] = {}

    if all(
        isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys()
    ):
        return optim_state_dict

    for param_group in optim.param_groups:
        pg_state.append({_PARAMS: []})
        for param in param_group[_PARAMS]:
            for fqn in info.fqn_param_mapping[param]:
                if fqn in info.shared_params_mapping:
                    in_params = False
                    for loaded_param_group in cast(
                        ListDictValueType, optim_state_dict[_PG]
                    ):
                        if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                            in_params = True
                            break
                else:
                    in_params = True
                if not in_params:
                    continue

                params = pg_state[-1][_PARAMS]
                assert isinstance(params, list)
                params.append(fqn)
                if param.requires_grad:
                    state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
                for loaded_param_group in cast(
                    ListDictValueType, optim_state_dict[_PG]
                ):
                    if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                        pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1

    for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
        idx = pg_mapping.get(id(param_group), -1)
        if idx == -1:
            continue
        for key, value in param_group.items():
            if key == _PARAMS:
                continue
            # TODO: check if value is the same if exists.
            pg_state[idx][key] = value

    return return_osd


@torch.no_grad()
def _load_optim_state_dict(
    model: nn.Module,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    state_dict: OptimizerStateType,
    info: _StateDictInfo,
) -> None:
    if not info.handle_optim:
        return

    for optim in optimizers:
        _init_optim_state(optim)
        if state_dict:
            if _STATE in state_dict:
                optim_state_dict = _split_optim_state_dict(
                    model, optim, state_dict, info
                )
            else:
                optim_state_dict = _unflatten_optim_state_dict(
                    optim, cast(Dict[str, ValueType], state_dict), info
                )
        else:
            optim_state_dict = {}
        if info.fsdp_modules:
            # We need to specially handle FlatParameter FSDP as
            # FlatParameter FSDP converts the FQNs.
            for original_fqn, _ in model.named_parameters():
                fqns = _get_fqns(model, original_fqn)
                fqns_with_compiler = _get_fqns(
                    model, original_fqn, skip_compiler_prefix=False
                )
                if fqns == fqns_with_compiler:
                    continue

                assert len(fqns) == 1
                # Use next(iter(...)) instead of pop() so we do not mutate the
                # set objects cached by the lru_cached _get_fqns.
                fqn = next(iter(fqns))
                fqn_with_compiler = next(iter(fqns_with_compiler))
                for g in optim_state_dict[_PG]:
                    val = cast(Dict[str, Any], g)
                    params = [
                        key.replace(fqn, fqn_with_compiler) for key in val[_PARAMS]
                    ]
                    val[_PARAMS] = params
                osd_state = cast(DictValueType, optim_state_dict[_STATE])
                for k in list(osd_state.keys()):
                    if fqn in k:
                        osd_state[k.replace(fqn, fqn_with_compiler)] = osd_state.pop(k)

            with info.fsdp_context():
                optim_state_dict = FSDP.optim_state_dict_to_load(
                    model, optim, optim_state_dict
                )
        elif info.full_state_dict:
            info.full_state_dict = False
            local_state_dict = _get_optim_state_dict(model, (optim,), info)
            info.full_state_dict = True
            device = None

            def _device(t):
                if t.dim() > 0:
                    nonlocal device
                    if device is None:
                        device = t.device
                    elif device != t.device:
                        raise ValueError("Device mismatch")
                return t

            _ = tree_map_only(torch.Tensor, _device, local_state_dict)
            assert device is not None
            flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)
            flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)
            if info.broadcast_from_rank0:
                _broadcast_state_dict(flatten_osd, flatten_local_osd, device=device)
            else:
                _distribute_state_dict(flatten_osd, flatten_local_osd, device=device)
            # The loop below handles the case where the local optimizer state_dict
            # is missing parameters that exist in the loaded optim_state_dict. The
            # missing entries are copied from optim_state_dict into the local
            # state_dict, so optim may end up with additional parameters.
            for optim_key in flatten_osd.keys():
                if optim_key not in flatten_local_osd:
                    assert optim_key in osd_mapping
                    flatten_local_osd[optim_key] = flatten_osd[optim_key]
                    local_osd_mapping[optim_key] = osd_mapping[optim_key]
            optim_state_dict = _unflatten_state_dict(
                flatten_local_osd, local_osd_mapping
            )

        # Note that we do not have to convert the FQN back to param id here if
        # order in optim.param_groups[idx][_PARAMS] is the same as the one in
        # optim_state_dict[_PG][idx][_PARAMS].
        _state_dict_fn(optim, "load_state_dict")(state_dict=optim_state_dict)


def get_model_state_dict(
    model: nn.Module,
    *,
    submodules: Optional[Set[nn.Module]] = None,
    options: Optional[StateDictOptions] = None,
) -> Dict[str, ValueType]:
    """
    Return the model state_dict of ``model``.

    See ``get_state_dict`` for the detailed usage.

    Args:
        model (nn.Module): the root nn.Module of the model.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``model``.

    :rtype: typing.Dict[str, ValueType]
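
    Example (a minimal sketch; it assumes ``model`` has already been wrapped,
    e.g. with FSDP, on an initialized process group):

        >>> # xdoctest: +SKIP
        >>> from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
        >>> # Default: a sharded state_dict whose keys are canonical FQNs.
        >>> sharded_sd = get_model_state_dict(model)
        >>> # Full state_dict offloaded to CPU; only rank0 keeps the tensors.
        >>> full_sd = get_model_state_dict(
        ...     model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
        ... )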
986    """
987    with _gc_context():
988        info = _verify_options(
989            model,
990            (),
991            optim_only=False,
992            submodules=submodules,
993            options=options,
994        )
995        model_state_dict = _get_model_state_dict(model, info)
996        _verify_state_dict(model_state_dict, {}, info)
997        return model_state_dict
998
999
1000def get_optimizer_state_dict(
1001    model: nn.Module,
1002    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
1003    *,
1004    submodules: Optional[Set[nn.Module]] = None,
1005    options: Optional[StateDictOptions] = None,
1006) -> OptimizerStateType:
1007    """
    Return the combined state_dict for ``optimizers``.

    See ``get_state_dict`` for the detailed usage.

    Args:
        model (nn.Module): the root nn.Module of the model.
        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``optimizers``.

    :rtype: OptimizerStateType
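
    Example (a minimal sketch; it assumes ``model`` is an already-wrapped module
    and ``optim`` is an optimizer built from its parameters):

        >>> # xdoctest: +SKIP
        >>> from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
        >>> # The returned keys are canonical FQNs instead of parameter IDs.
        >>> optim_sd = get_optimizer_state_dict(model, optim)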
1026    """
1027    with _gc_context():
1028        optimizers = (
1029            (optimizers,)
1030            if isinstance(optimizers, torch.optim.Optimizer)
1031            else tuple(optimizers)
1032        )
1033        info = _verify_options(
1034            model,
1035            optimizers,
1036            optim_only=True,
1037            submodules=submodules,
1038            options=options,
1039        )
1040        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
1041        _verify_state_dict({}, optim_state_dict, info)
1042        return optim_state_dict
1043
1044
1045def get_state_dict(
1046    model: nn.Module,
1047    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
1048    *,
1049    submodules: Optional[Set[nn.Module]] = None,
1050    options: Optional[StateDictOptions] = None,
1051) -> Tuple[Dict[str, ValueType], OptimizerStateType]:
1052    """
    Return the model state_dict and optimizers state_dict.

    ``get_state_dict`` can process any module that is parallelized by PyTorch
    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any
    combination of these parallelisms. The main functions of ``get_state_dict``
    are: 1.) returning a model and optimizer state_dict that can be resharded
    with a different number of trainers and/or different parallelisms.
    2.) hiding the parallelism-specific state_dict APIs. Users don't have to call
    these APIs.
    3.) sanity checking the result state_dict.

    The keys of the result state dictionary are the canonical FQNs (Fully
    Qualified Names). A canonical FQN refers to the FQN based on a parameter's
    position in an nn.Module hierarchy. More specifically, the canonical FQN of a
    parameter is the FQN returned by ``module.named_parameters()`` or
    ``module.named_buffers()`` when the module is not distributed by any
    parallelisms. Since the optimizer internally uses parameter IDs to represent
    a parameter, there will be a conversion from the parameter IDs to the
    canonical FQNs when calling this API.

    ``get_state_dict`` can also process a module that is not parallelized. In
    such a case, ``get_state_dict`` only performs one function -- converting the
    optimizer parameter IDs to the canonical FQNs.

    Example:
        >>> # xdoctest: +SKIP
        >>> import copy
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.checkpoint.state_dict import get_state_dict

        >>> fsdp_model = FSDP(copy.deepcopy(model))
        >>> fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)
        >>> ddp_model = DDP(copy.deepcopy(model))
        >>> ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)


        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)

        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
        >>> # the asserts will fail.
        >>> assert ddp_state_dict == fsdp_state_dict
        >>> assert ddp_optim_state_dict == fsdp_optim_state_dict

    Args:
        model (nn.Module): the root nn.Module of the model.
        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[Set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        ``Tuple`` that contains the model state_dict and the optimizer state_dict.

    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
1113    """
1114
1115    with _gc_context():
1116        optimizers = (
1117            (optimizers,)
1118            if isinstance(optimizers, torch.optim.Optimizer)
1119            else tuple(optimizers)
1120        )
1121        info = _verify_options(
1122            model,
1123            optimizers,
1124            optim_only=False,
1125            submodules=submodules,
1126            options=options,
1127        )
1128        model_state_dict = _get_model_state_dict(model, info)
1129        optim_state_dict = _get_optim_state_dict(model, optimizers, info)
1130        _verify_state_dict(model_state_dict, optim_state_dict, info)
1131        return model_state_dict, optim_state_dict
1132
1133
1134def _unflatten_model_state_dict(
1135    model: nn.Module,
1136    state_dict: Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]],
1137) -> Dict[str, ValueType]:
1138    if not state_dict:
1139        return {}
1140
1141    if isinstance(next(iter(state_dict.keys())), nn.Module):
        warnings.warn(
            "Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]`` "
            "is deprecated and will be removed in 2.5. If you need this "
            "feature, please preprocess the model_state_dict to achieve the "
            "same functionality.",
            FutureWarning,
        )
        cast_state_dict = cast(Dict[nn.Module, Dict[str, ValueType]], state_dict)
        new_state_dict: Dict[str, ValueType] = {}
        for submodule, sub_state_dict in cast_state_dict.items():
            for name, m in model.named_modules():
                if m != submodule:
                    continue

                fqns = _get_fqns(model, name)
                assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
                prefix = f"{next(iter(fqns))}."
                new_state_dict.update(
                    {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
                )
        return new_state_dict
    else:
        return cast(Dict[str, ValueType], state_dict)


def set_model_state_dict(
    model: nn.Module,
    model_state_dict: Dict[str, ValueType],
    *,
    options: Optional[StateDictOptions] = None,
) -> _IncompatibleKeys:
    """Load the model state_dict.

    The counterpart of ``get_model_state_dict`` to set the state_dict to the
    model. See ``set_state_dict`` for the detailed usage.

    Args:
        model (nn.Module): the root nn.Module of the model.
        model_state_dict: (Dict[str, ValueType]):
           the model state_dict to load. If the key of the ``model_state_dict``
           is nn.Module, the key is a submodule of ``model`` and the value should
           be the state_dict of the submodule. When loading the state_dict,
           the prefix of the submodule will be prepended to the keys of its
           state_dict.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    :type model_state_dict: typing.Dict[str, ValueType]
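
    Example (a minimal sketch; it assumes ``model`` is wrapped the same way it
    was when the state_dict was obtained):

        >>> # xdoctest: +SKIP
        >>> from torch.distributed.checkpoint.state_dict import get_model_state_dict, set_model_state_dict
        >>> model_sd = get_model_state_dict(model)
        >>> # ... save and later reload ``model_sd`` with a checkpointing API ...
        >>> incompatible_keys = set_model_state_dict(model, model_sd)
        >>> assert not incompatible_keys.missing_keys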
1195    """
1196    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
1197        model, model_state_dict
1198    )
1199    with _gc_context():
1200        info = _verify_options(model, (), optim_only=False, options=options)
1201
1202        _verify_state_dict(model_state_dict, {}, info)
1203        return _load_model_state_dict(model, model_state_dict, info)
1204
1205
1206def set_optimizer_state_dict(
1207    model: nn.Module,
1208    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
1209    optim_state_dict: OptimizerStateType,
1210    *,
1211    options: Optional[StateDictOptions] = None,
1212) -> None:
1213    """Load the optimizers state_dict.
1214
    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
    optimizers. See ``set_state_dict`` for the detailed usage.

    Args:
        model (nn.Module): the root nn.Module of the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        None

    :type optim_state_dict: OptimizerStateType
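
    Example (a minimal sketch; it assumes ``model`` and ``optim`` are constructed
    the same way they were when the state_dict was obtained):

        >>> # xdoctest: +SKIP
        >>> from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict, set_optimizer_state_dict
        >>> optim_sd = get_optimizer_state_dict(model, optim)
        >>> # ... save and later reload ``optim_sd`` with a checkpointing API ...
        >>> set_optimizer_state_dict(model, optim, optim_sd)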
1232    """
1233    with _gc_context():
1234        optimizers = (
1235            (optimizers,)
1236            if isinstance(optimizers, torch.optim.Optimizer)
1237            else tuple(optimizers)
1238        )
1239        info = _verify_options(model, optimizers, optim_only=True, options=options)
1240
1241        _verify_state_dict({}, optim_state_dict, info)
1242        _load_optim_state_dict(model, optimizers, optim_state_dict, info)
1243
1244
1245def set_state_dict(
1246    model: nn.Module,
1247    optimizers: Union[torch.optim.Optimizer, Iterable[torch.optim.Optimizer]],
1248    *,
1249    model_state_dict: Dict[str, ValueType],
1250    optim_state_dict: OptimizerStateType,
1251    options: Optional[StateDictOptions] = None,
1252) -> _IncompatibleKeys:
1253    """Load the model state_dict and optimizers state_dict.
1254
    The counterpart of ``get_state_dict`` to set the state_dict to the model and
    optimizers. The given ``model_state_dict`` and ``optim_state_dict`` do not
    have to be returned by ``get_state_dict`` but must meet the following
    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,
    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,
    3) optimizer state_dict cannot contain the parameter IDs; the keys should be
    the canonical FQNs.

    Args:
        model (nn.Module): the root nn.Module of the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
           the model state_dict to load. If the key of the ``model_state_dict``
           is nn.Module, the key is a submodule of ``model`` and the value should
           be the state_dict of the submodule. When loading the state_dict,
           the prefix of the submodule will be prepended to the keys of its
           state_dict.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys of the model state_dict.
            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.

    :type model_state_dict: typing.Dict[str, ValueType]
    :type optim_state_dict: OptimizerStateType
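
    Example (a minimal sketch; it assumes ``model`` and ``optim`` are wrapped and
    constructed the same way they were when the state_dicts were obtained):

        >>> # xdoctest: +SKIP
        >>> from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
        >>> model_sd, optim_sd = get_state_dict(model, optim)
        >>> # ... save and later reload both dicts with a checkpointing API ...
        >>> incompatible_keys = set_state_dict(
        ...     model,
        ...     optim,
        ...     model_state_dict=model_sd,
        ...     optim_state_dict=optim_sd,
        ... )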
1285    """
1286
1287    model_state_dict: Dict[str, ValueType] = _unflatten_model_state_dict(
1288        model, model_state_dict
1289    )
1290    with _gc_context():
1291        optimizers = (
1292            (optimizers,)
1293            if isinstance(optimizers, torch.optim.Optimizer)
1294            else tuple(optimizers)
1295        )
1296        info = _verify_options(
1297            model, optimizers, optim_only=not model_state_dict, options=options
1298        )
1299
1300        _verify_state_dict(model_state_dict, optim_state_dict, info)
1301        _load_optim_state_dict(model, optimizers, optim_state_dict, info)
1302        return _load_model_state_dict(model, model_state_dict, info)
1303
1304
1305# TODO: correct the state_dict function signature.
1306# TODO: this API is not yet fully tested. Make it private
1307@no_type_check
1308def _patch_model_state_dict(
1309    model: nn.Module,
1310    *,
1311    options: Optional[StateDictOptions] = None,
1312) -> None:
1313    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.
1314
1315    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to
1316    be a partial function to call ``get_state_dict`` and ``set_state_dict``.
1317
    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_model_state_dict

        model = FSDP(model)
        _patch_model_state_dict(model)

    Args:
        model (nn.Module): the root nn.Module of the model.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """

    _state_dict_call = functools.partial(
        get_model_state_dict,
        model=model,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    model.state_dict = state_dict_call

    _load_state_dict_call = functools.partial(
        set_model_state_dict,
        model=model,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        _load_state_dict_call(model_state_dict=state_dict)

    model.load_state_dict = load_state_dict_call

    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)


# TODO: correct the load_state_dict function signature.
# TODO: this API is not yet fully tested. Make it private
@no_type_check
def _patch_optimizer_state_dict(
    model: nn.Module,
    *,
    optimizers: Tuple[torch.optim.Optimizer, ...],
    options: Optional[StateDictOptions] = None,
) -> None:
    """Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.

    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to
    be a partial function to call ``get_state_dict`` and ``set_state_dict``.

    Note that if there are multiple optimizers, all of the optimizers will be patched.
    So users only need to call ``state_dict()`` on one of the optimizers to get the
    full result.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import _patch_optimizer_state_dict

        model = FSDP(model)
        _patch_optimizer_state_dict(model, optimizers=(optim,))

    Args:
        model (nn.Module): the root nn.Module of the model.
        optimizers (Tuple[torch.optim.Optimizer, ...]): the optimizers that are
            used to optimize ``model``.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    """

    _state_dict_call = functools.partial(
        get_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def state_dict_call():
        return _state_dict_call()

    _load_state_dict_call = functools.partial(
        set_optimizer_state_dict,
        model=model,
        optimizers=optimizers,
        options=options,
    )

    def load_state_dict_call(state_dict: Dict[str, Any]):
        _load_state_dict_call(optim_state_dict=state_dict)

    _patched_state_dict.add(state_dict_call)
    _patched_state_dict.add(load_state_dict_call)
    optimizers = (
        (optimizers,)
        if isinstance(optimizers, torch.optim.Optimizer)
        else tuple(optimizers)
    )
    for optim in optimizers:
        optim.state_dict = state_dict_call
        optim.load_state_dict = load_state_dict_call