xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/_optim_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import functools
4import logging
5import warnings
6from contextlib import ExitStack
7from dataclasses import dataclass, field
8from typing import (
9    Any,
10    cast,
11    Dict,
12    Iterable,
13    Iterator,
14    List,
15    NamedTuple,
16    no_type_check,
17    Optional,
18    Sequence,
19    Set,
20    Tuple,
21    TYPE_CHECKING,
22    Union,
23)
24
25import torch
26import torch.distributed as dist
27import torch.distributed.fsdp._traversal_utils as traversal_utils
28import torch.nn as nn
29from torch.distributed._state_dict_utils import _gather_state_dict
30from torch.distributed.distributed_c10d import _get_pg_default_device
31from torch.distributed.fsdp._common_utils import (
32    _apply_to_modules,
33    _FSDPState,
34    _get_module_fsdp_state_if_fully_sharded_module,
35    _get_param_to_fqns,
36    _module_handle,
37    _named_parameters_with_duplicates,
38    clean_tensor_name,
39)
40from torch.distributed.fsdp._debug_utils import SimpleProfiler
41from torch.distributed.fsdp._flat_param import FlatParameter, FlatParamHandle
42from torch.distributed.fsdp._fsdp_extensions import (
43    _ext_chunk_dtensor,
44    _ext_chunk_tensor,
45)
46from torch.distributed.fsdp._runtime_utils import (
47    _lazy_init,
48    _reset_flat_param_grad_info_if_needed,
49)
50from torch.distributed.fsdp.api import (
51    ShardingStrategy,
52    StateDictSettings,
53    StateDictType,
54)
55from torch.distributed.tensor import DTensor, Replicate
56from torch.utils._pytree import tree_map_only
57
58
59if TYPE_CHECKING:
60    from torch.distributed._shard.sharded_tensor import ShardedTensor
61
62
63logger = logging.getLogger(__name__)
64
65
66@dataclass
67class FSDPParamInfo:
68    state: _FSDPState
69    handle: FlatParamHandle
70    param_indices: Dict[str, int]
71    param_requires_grad: List[bool]
72
73
74def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
75    keys = sorted(dictionary.keys())
76    for k in keys:
77        yield k, dictionary[k]
78
79
80@dataclass
81class _ConsolidatedOptimState:
82    """
83    This holds the consolidated optimizer state on the target rank. Positive-
84    dimension tensor state is communicated across ranks, while zero-dimension
85    tensor state and non-tensor state are taken directly from the target rank.
86
87    PyTorch version 1.12 moved to using zero-dimension tensors for scalar
88    values, but user-implemented optimizers may still use a float (i.e., a
89    non-tensor). Thus, we support both and handle them identically.
90
91    Attributes:
92        tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension
93            tensor state name to the unsharded flat tensor representing the
94            state.
95        zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero-
96            dimension tensor state name to its value.
97        non_tensor_state (Dict[str, Any]): Mapping from non-tensor state
98            name to its value.
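
    Example (illustrative sketch for an Adam-like optimizer; the entries shown
    here are hypothetical, not produced by this module)::

        state = _ConsolidatedOptimState()
        # Unsharded flat tensor state communicated across ranks
        state.tensor_state["exp_avg"] = torch.zeros(1024)
        # Zero-dimension tensor state taken directly from the target rank
        state.zero_dim_tensor_state["step"] = torch.tensor(10.0)
        # Non-tensor state taken directly from the target rank
        state.non_tensor_state["step_as_float"] = 10.0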
99    """
100
101    tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
102    zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict)
103    non_tensor_state: Dict[str, Any] = field(default_factory=dict)
104
105
106class _PosDimTensorInfo(NamedTuple):
107    """
108    Metadata for positive-dimension tensors used internally for
109    :meth:`scatter_full_optim_state_dict`.
110
111    Attributes:
112        shape (torch.Size): Sharded tensor shape (which is equal to the
113            unsharded tensor shape if the tensor is optimizer state for a
114            non-FSDP parameter and is hence not sharded).
115        dtype (torch.dtype): Data type of the tensor.
116    """
117
118    shape: torch.Size
119    dtype: torch.dtype
120
121
122class _OptimStateKey(NamedTuple):
123    """
124    This represents an optimizer state key that may be used commonly across
125    ranks. It is based on the unflattened parameter names rather than parameter
126    IDs to make it independent of each rank's own optimizer construction.
127    """
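
    Example (illustrative only; the parameter names are hypothetical)::

        key = _OptimStateKey(
            unflat_param_names=("layer.weight", "layer.bias"),
            is_fsdp_managed=True,
        )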
128
129    unflat_param_names: Tuple[str, ...]
130    is_fsdp_managed: bool
131
132
133def _unflatten_optim_state(
134    fsdp_param_info: FSDPParamInfo,
135    flat_param_state: Dict[str, Any],
136    to_save: bool,
137    shard_state: bool,
138    cpu_offload: bool,
139) -> List[Dict[str, Any]]:
140    """
141    Unflattens the optimizer state, consisting of the "state" part and the
142    "param_groups" part. Unflattening the "state" part involves consolidating
143    the state on the target rank and remapping from flattened to unflattened
144    parameter IDs, and the "param_groups" part only involves remapping from
145    flattened to unflattened parameter IDs.
146
147    Args:
148        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
149            mapping from FQN to original parameter index.
150        flat_param_state (Dict[str, Any]): Entry for the flat parameter in the
151            "state" part of the optimizer state dict.
152        to_save (bool): Whether to save the state on this rank.
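        shard_state (bool): Whether to shard the unflattened tensor state
            across ranks instead of keeping the full unsharded state on the
            target rank.
        cpu_offload (bool): Whether to move the unflattened tensor state to
            CPU.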
153
154    Returns:
155        List[Dict[str, Any]]: A :class:`list` holding the entries in the
156        "state" part of the optimizer state dict corresponding to the
157        unflattened parameters comprising the flat parameter if on the target
158        rank or an empty :class:`list` otherwise. The final optimizer state
159        dict will need to map these entries using the proper unflattened
160        parameter IDs.
161    """
162    assert (
163        not shard_state or to_save
164    ), "If ``shard_state`` is True, ``to_save`` has to be True."
165    consolidated_state = _communicate_optim_state(
166        fsdp_param_info,
167        flat_param_state,
168    )
169    if to_save:
170        unflat_param_state = _unflatten_communicated_optim_state(
171            fsdp_param_info,
172            consolidated_state,
173            shard_state,
174        )
175        for optim_state in unflat_param_state:
176            # We can't use .items() below because we'd run into a concurrent modification error
177            if cpu_offload:
178                for key in list(optim_state.keys()):
179                    state = optim_state[key]
180                    if not isinstance(state, torch.Tensor):
181                        continue
182                    optim_state[key] = state.cpu()
183        return unflat_param_state
184    else:
185        return []
186
187
188def _is_zero_dim_tensor(x: Any) -> bool:
189    return torch.is_tensor(x) and x.dim() == 0
190
191
192def _communicate_optim_state(
193    fsdp_param_info: FSDPParamInfo,
194    flat_param_state: Dict[str, Any],
195) -> _ConsolidatedOptimState:
196    """
197    Communicates the optimizer state for a flat parameter across ranks. All
198    ranks will hold the entire non-sharded optimizer state on GPU.
199
200    If ``N`` is the number of tensor optimizer states in the optimizer state
201    dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1``
202    otherwise (where the plus 1 comes from all-gathering the padding per rank).
203
204    Args:
205        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
206            mapping from FQN to original parameter index.
207        flat_param_state (Dict[str, Any]): The entry in the "state" part of the
208            optimizer state dict corresponding to the flat parameter.
209
210    Returns:
211        _ConsolidatedOptimState: Consolidated optimizer state for the target
212        flat parameter.
213    """
214    fsdp_state = fsdp_param_info.state
215    flat_param = fsdp_param_info.handle.flat_param
216    state = _ConsolidatedOptimState()
217    tensor_state, zero_dim_tensor_state, non_tensor_state = (
218        state.tensor_state,
219        state.zero_dim_tensor_state,
220        state.non_tensor_state,
221    )
222
223    for state_name, value in sorted_items(flat_param_state):
224        # Positive-dimension tensor state: communicate across ranks
225        if torch.is_tensor(value) and value.dim() > 0:
226            # If the parameter is not sharded, then neither is the
227            # positive-dimension tensor state, so no need to communicate it --
228            # we take the target rank's value
229            if (
230                fsdp_state.world_size == 1
231                or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
232            ):
233                tensor_state[state_name] = value
234                continue
235            assert (
236                fsdp_state.compute_device is not None
237            ), "compute_device has not been initialized"
238            if value.device.type != fsdp_state.compute_device.type:
239                value = value.to(fsdp_state.compute_device)
240            # Assume that positive-dimension tensor optimizer state
241            # has the same shape as the sharded flat parameter
242            buffer_size = flat_param._full_param_padded.size()  # type: ignore[attr-defined]
243            tensor_buffer = value.new_zeros(*buffer_size)
244            dist.all_gather_into_tensor(
245                tensor_buffer, value, group=fsdp_state.process_group
246            )
247            fsdp_state._device_handle.synchronize()
248            unpadded_numel = cast(
249                nn.Parameter, flat_param._unpadded_unsharded_size
250            ).numel()
251            tensor_state[state_name] = tensor_buffer[:unpadded_numel]
252        # Zero-dimension tensor state and non-tensor state: take this rank's
253        # value directly
254        else:
255            if _is_zero_dim_tensor(value):
256                zero_dim_tensor_state[state_name] = value.detach().clone()
257            else:
258                non_tensor_state[state_name] = value
259    return state
260
261
262def _unflatten_communicated_optim_state(
263    fsdp_param_info: FSDPParamInfo,
264    state: _ConsolidatedOptimState,
265    shard_state: bool,
266) -> List[Dict[str, Any]]:
267    """
268    Unflattens the communicated optimizer state (given by ``tensor_state``,
269    ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat
270    parameter. This should only be called on the target rank.
271
272    Args:
273        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
274            mapping from FQN to original parameter index.
275        state (_ConsolidatedOptimState): Consolidated optimizer state.
276
277    Returns:
278        List[Dict[str, Any]]: A :class:`list` holding the entries in the
279        "state" part of the optimizer state dict corresponding to the
280        unflattened parameters comprising the flat parameter. The final
281        optimizer state dict will need to map these entries using the proper
282        unflattened parameter IDs.
283    """
284    fsdp_state = fsdp_param_info.state
285    handle = fsdp_param_info.handle
286    flat_param = handle.flat_param
287    unflat_param_state: List[Dict[str, Any]] = []
288    flat_param_views: Dict[str, Iterator] = {}
289    num_unflat_params = flat_param._num_params
290    tensor_state, zero_dim_tensor_state, non_tensor_state = (
291        state.tensor_state,
292        state.zero_dim_tensor_state,
293        state.non_tensor_state,
294    )
295
296    for _ in range(num_unflat_params):
297        unflat_state_param = {}
298        # Add positive-dimension tensor state: unflatten with views
299        for state_name, flat_tensor in sorted_items(tensor_state):
300            views_generated = state_name in flat_param_views
301            if not views_generated:
302                views = handle._get_unflat_views(flat_tensor)
303                flat_param_views[state_name] = views
304            else:
305                views = flat_param_views[state_name]
306            optim_state: Union[torch.Tensor, ShardedTensor, DTensor] = next(views)
307            if shard_state:
308                osd_config = fsdp_state._optim_state_dict_config
309                if getattr(osd_config, "_use_dtensor", False):
310                    assert fsdp_state._device_mesh is not None
311                    optim_state = _ext_chunk_dtensor(
312                        optim_state,
313                        fsdp_state.rank,
314                        fsdp_state._device_mesh,
315                        fsdp_state._fsdp_extension,
316                    )
317                else:
318                    assert fsdp_state.process_group is not None
319                    optim_state = _ext_chunk_tensor(
320                        optim_state,
321                        fsdp_state.rank,
322                        fsdp_state.world_size,
323                        fsdp_state._device_handle.device_count(),
324                        fsdp_state.process_group,
325                        fsdp_state._fsdp_extension,
326                    )
327            unflat_state_param[state_name] = optim_state
328
329        # Add zero-dimension tensor state: take the target rank's value
330        for state_name, zero_dim_tensor in sorted_items(zero_dim_tensor_state):
331            unflat_state_param[state_name] = zero_dim_tensor
332        # Add non-tensor state: take the target rank's value
333        for state_name, non_tensor in sorted_items(non_tensor_state):
334            unflat_state_param[state_name] = non_tensor
335        unflat_param_state.append(unflat_state_param)
336    return unflat_param_state
337
338
339def _broadcast_processed_state(
340    fsdp_state: _FSDPState,
341    optim_state: Dict[str, Any],
342    group: Optional[dist.ProcessGroup],
343) -> Dict[str, Any]:
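    """
    Broadcasts the processed optimizer state from rank 0 to all ranks. On rank
    0, zero-dimension tensors are moved to CPU and positive-dimension tensors
    are replaced with ``_PosDimTensorInfo`` metadata so that only lightweight
    objects travel through ``broadcast_object_list``; the tensor data itself is
    broadcast later via :func:`_broadcast_state`.
    """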
344    objects: List[Any] = [None]
345    if dist.get_rank(group) == 0:
346        objects[0] = tree_map_only(
347            torch.Tensor,
348            lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype),  # type: ignore[union-attr]
349            optim_state,
350        )
351    dist.broadcast_object_list(objects, src=0, group=group)
352    if dist.get_rank(group) == 0:
353        return optim_state
354    else:
355        return objects[0]
356
357
358def _broadcast_state(
359    fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup]
360) -> Any:
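    """
    Broadcasts a single positive-dimension tensor state from rank 0 to all
    ranks. Non-tensor and zero-dimension tensor states are returned as-is on
    the calling rank; non-zero ranks allocate a receive buffer on the compute
    device from the ``_PosDimTensorInfo`` metadata produced by
    :func:`_broadcast_processed_state`.
    """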
361    if dist.get_rank(group) == 0:
362        if not isinstance(state, torch.Tensor) or state.dim() == 0:
363            return state
364        tensor = state.to(fsdp_state.compute_device)
365    else:
366        if isinstance(state, torch.Tensor):
367            assert state.dim() == 0, (
368                "For non-zero ranks, a tensor state should have zero dimension, "
369                f"but got the state with shape {state.shape}."
370            )
371            return state
372        elif not isinstance(state, _PosDimTensorInfo):
373            return state
374        tensor = torch.zeros(
375            state.shape, dtype=state.dtype, device=fsdp_state.compute_device
376        )
377    dist.broadcast(tensor, src=0, group=group)
378    return tensor
379
380
381def _shard_orig_param_state(
382    fsdp_param_info: FSDPParamInfo,
383    fqn: str,
384    optim_state: Dict[str, Any],
385) -> Dict[str, Any]:
386    """
387    Shard the optimizer state for the original parameter with the name ``fqn``.
388    This API should only be used when ``use_orig_params`` is True.
389    """
390    if not optim_state:
391        return {}
392    fsdp_state = fsdp_param_info.state
393    flat_param = fsdp_param_info.handle.flat_param
394    param_idx = fsdp_param_info.param_indices[fqn]
395    shard_param_info = flat_param._shard_param_infos[param_idx]  # type: ignore[attr-defined]
396    optim_state = _gather_state_dict(
397        optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device
398    )
399    if not shard_param_info.in_shard:
400        return {}
401    # Flatten and shard the state.
402    new_optim_state: Dict[str, Any] = {}
403    intra_param_start_idx = shard_param_info.intra_param_start_idx
404    intra_param_end_idx = shard_param_info.intra_param_end_idx
405    for state_name, value in optim_state.items():
406        if (
407            torch.is_tensor(value)
408            and value.dim() > 0
409            and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
410        ):
411            value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()  # type: ignore[operator]
412        new_optim_state[state_name] = value
413    return new_optim_state
414
415
416def _flatten_optim_state_dict(
417    optim_state_dict: Dict[str, Any],
418    model: nn.Module,
419    use_orig_params: bool = False,
420    optim: Optional[torch.optim.Optimizer] = None,
421    rank0_only: bool = False,
422    group: Optional[dist.ProcessGroup] = None,
423) -> Dict[str, Any]:
424    """
425    Flattens the full optimizer state dict, still keying by unflattened parameter
426    names.
427
428    If ``use_orig_params`` is True, each rank will have all FSDP-managed
429    parameters but some of these parameters may be empty due to the sharding.
430    For a regular optim.Optimizer, states for those empty parameters will
431    not be initialized. So, when aggregating the FQNs across ranks, no assert
432    will be raised on a rank even if it does not have all the states -- it is
433    valid and FSDP knows how to aggregate them. However, FSDP has to ignore
434    handling those parameters that are not managed by FSDP and do not exist on
435    the local rank -- they are managed by other parallelisms and FSDP does not
436    know how to handle/aggregate them.
437
438    Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to
439    flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require
440    all the states even if the corresponding parameters are empty. To this end,
441    ``optim`` will be used to get the initial state of the empty parameters.
442    ``optim`` should only be non-None if it is a KeyedOptimizer or
443    NamedOptimizer.
444
445    Returns:
446        Dict[str, Any]: The flattened optimizer state dict.
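
    Example (illustrative sketch; ``model``, the FQN ``"lin.weight"``, and the
    state values are hypothetical)::

        full_osd = {"state": {"lin.weight": {"exp_avg": ..., "step": ...}}}
        flat_osd = _flatten_optim_state_dict(full_osd, model)
        # flat_osd["state"] maps `_OptimStateKey` s (or plain strings for
        # user-defined state) to flattened, possibly sharded, values.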
447    """
448    SimpleProfiler.reset()
449
450    unflat_osd = optim_state_dict
451    if "state" not in unflat_osd and not rank0_only:
452        raise ValueError(
453            '`optim_state_dict` must have the key "state" '
454            "to be a valid optimizer state dict"
455        )
456    param_to_fqns = _get_param_to_fqns(model)
457    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
458    fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state
459
460    # Broadcast unflat_osd without non-scalar tensor if rank0_only is True.
461    if rank0_only:
462        unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group)
463
464    # Construct the "state" part
465    flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {}
466    unflat_osd_state = unflat_osd["state"]
467    all_state_keys = set(unflat_osd_state.keys())
468
469    for param, fqns in param_to_fqns.items():
470        fqn = fqns[0]
471        if fqn not in unflat_osd_state:
472            continue
473        all_state_keys.difference_update(fqns)
474
475        if rank0_only:
476            for fqn in fqns:
477                if not unflat_osd_state[fqn]:
478                    continue
479                for state_name in unflat_osd_state[fqn].keys():
480                    unflat_osd_state[fqn][state_name] = _broadcast_state(
481                        fsdp_state, unflat_osd_state[fqn][state_name], group=group
482                    )
483            fqn = fqns[0]
484        if fqn in fqn_to_fsdp_param_info:
485            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
486            if use_orig_params:
487                with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
488                    flat_state = _shard_orig_param_state(
489                        fsdp_param_info,
490                        fqn,
491                        unflat_osd_state[fqn],
492                    )
493            else:
494                flat_state = _flatten_optim_state(
495                    fsdp_param_info,
496                    unflat_osd_state,
497                    fqns,
498                )
499            key = _OptimStateKey(tuple(fqns), True)
500            # Only include non-empty states, as expected by
501            # `torch.optim.Optimizer`, unless the optimizer is a KeyedOptimizer
502            # or NamedOptimizer.
503            if flat_state:
504                flat_osd_state[key] = flat_state
505            elif use_orig_params:
506                assert (
507                    len(fqns) == 1
508                ), f"use_orig_params is True but there are multiple FQNs, {fqns}."
509                if optim is not None:  # NamedOptimizer or KeyedOptimizer case.
510                    state = optim.state.get(param, None)  # type: ignore[call-overload]
511                    if state is not None:
512                        flat_osd_state[key] = copy.deepcopy(state)
513                    else:
514                        warnings.warn(
515                            f"optim_state[{key}] is not on rank {fsdp_state.rank}."
516                        )
517
518            else:
519                raise RuntimeError(
520                    f"The state of {key} is empty. This should only happen when "
521                    "use_orig_params=True."
522                )
523        else:  # do not flatten non-FSDP parameters' states
524            assert len(fqns) == 1
525            key = _OptimStateKey(tuple(fqns), False)
526            flat_osd_state[key] = copy.copy(unflat_osd_state[fqn])
527
528        if rank0_only:
529            for fqn in fqns:
530                if not unflat_osd_state[fqn]:
531                    continue
532                for state_name, param_state in list(unflat_osd_state[fqn].items()):
533                    if fsdp_state.rank > 0:
534                        # Dereference the tensor so that PyTorch can collect the memory.
535                        del unflat_osd_state[fqn][state_name]
536                    else:
537                        # Move the tensor in the original osd back to CPU so that the
538                        # original osd is left unaffected.
539                        unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][
540                            state_name
541                        ].cpu()
542
543    # Handle user-defined state, i.e., state that is not associated with parameters.
544    for key in all_state_keys:
545        user_state = unflat_osd_state[key]
546        if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params:
547            user_state = _broadcast_state(fsdp_state, user_state, group=group)
548        flat_osd_state[key] = copy.copy(user_state)
549
550    SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ")
551    # Construct the "param_groups" part -- copy as is since it will be
552    # rekeyed later according to the target rank's optimizer
553    # Only copy param_groups if it exists in unflat_osd
554    if "param_groups" in unflat_osd:
555        flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"])
556        return {"state": flat_osd_state, "param_groups": flat_osd_param_groups}
557    else:
558        return {"state": flat_osd_state}
559
560
561def _flatten_optim_state(
562    fsdp_param_info: FSDPParamInfo,
563    unflat_osd_state: Dict[str, Dict[str, Any]],
564    unflat_param_names: List[str],
565) -> Dict[str, Any]:
566    """
567    Flattens the optimizer state in ``full_optim_state_dict`` for a single
568    flat parameter in ``fsdp_param_info`` corresponding to the unflattened
569    parameter names in ``unflat_param_names``.
570
571    Args:
572        fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a
573            mapping from FQN to original parameter index.
574        unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the
575            optimizer state dict corresponding to the unflattened parameters.
576        unflat_param_names (List[str]): A :class:`list` of unflattened
577            parameter names corresponding to the flat parameter ``flat_param``.
578
579    Returns:
580        Dict[str, Any]: A :class:`dict` mapping state names to their values for
581        a particular flat parameter. The sharded optimizer state dict's "state"
582        part will map a key to this returned value.
583    """
584    fsdp_state = fsdp_param_info.state
585    handle = fsdp_param_info.handle
586    flat_param = handle.flat_param
587    num_unflat_params = len(unflat_param_names)
588    assert num_unflat_params > 0, (
589        "Expects at least one unflattened parameter corresponding to the "
590        "flat parameter"
591    )
592    unflat_param_shapes = flat_param._shapes
593    num_unflat_param_shapes = len(unflat_param_shapes)
594    assert (
595        num_unflat_params == num_unflat_param_shapes
596    ), f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
597
598    # Check if these unflattened parameters have any optimizer state
599    has_state = [
600        bool(unflat_param_name in unflat_osd_state)
601        for unflat_param_name in unflat_param_names
602    ]
603    # If none of the unflattened parameters comprising this flat parameter have
604    # any state, then we do not want an entry in the optimizer state dict
605    if not any(has_state):
606        return {}  # no need to flatten any state
607    # There may still be some unflattened parameters with state and some
608    # without
609    unflat_param_states = [
610        _gather_state_dict(
611            unflat_osd_state[unflat_param_name],
612            pg=fsdp_state.process_group,
613            device=fsdp_state.compute_device,
614        )
615        if unflat_param_name in unflat_osd_state
616        else None
617        for unflat_param_name in unflat_param_names
618    ]
619    # Check that the unflattened parameters have the same state names
620    state_names = None
621    for unflat_param_state in unflat_param_states:
622        if unflat_param_state is None:
623            continue
624        if state_names is None:
625            state_names = set(unflat_param_state.keys())
626        else:
627            if state_names != set(unflat_param_state.keys()):
628                raise ValueError(
629                    "Differing optimizer state names for the unflattened "
630                    f"parameters: {unflat_param_names}"
631                )
632    assert state_names is not None
633
634    # Flatten the state
635    flat_state: Dict[str, Any] = {}
636    for state_name in state_names:
637        state_values = [
638            unflat_param_state[state_name] if unflat_param_state is not None else None
639            for unflat_param_state in unflat_param_states
640        ]
641        non_none_state_values = [v for v in state_values if v is not None]
642        # If all ranks have None, this is a None value
643        if not non_none_state_values:
644            flat_state[state_name] = None
645            continue
646        are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True
647        for v in non_none_state_values:
648            are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0
649            are_zero_dim_tensors &= _is_zero_dim_tensor(v)
650            are_non_tensors &= not torch.is_tensor(v)
651        types = {type(v) for v in non_none_state_values}
652        if len(types) != 1 or not (
653            are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors
654        ):
655            raise ValueError(
656                f"Differing optimizer state types for state {state_name}, "
657                f"values {non_none_state_values}, and unflattened parameter "
658                f"names {unflat_param_names}"
659            )
660        if are_pos_dim_tensors:
661            flat_tensor = _flatten_tensor_optim_state(
662                state_name,
663                state_values,
664                unflat_param_names,
665                unflat_param_shapes,
666                handle,
667            )
668            # Shard the flattened tensor immediately to minimize max memory
669            # usage
670            if (
671                fsdp_state.world_size != 1
672                and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
673            ):
674                sharded_flat_tensor, _ = FlatParamHandle._get_shard(
675                    flat_tensor,
676                    fsdp_state.rank,
677                    fsdp_state.world_size,
678                )
679            else:
680                sharded_flat_tensor = flat_tensor
681            flat_state[state_name] = sharded_flat_tensor
682        elif are_zero_dim_tensors:
683            flat_state[state_name] = _flatten_zero_dim_tensor_optim_state(
684                state_name,
685                state_values,
686                unflat_param_names,
687            )
688        else:
689            assert are_non_tensors
690            flat_state[state_name] = _flatten_non_tensor_optim_state(
691                state_name,
692                state_values,
693                unflat_param_names,
694            )
695
696    return flat_state
697
698
699def _flatten_tensor_optim_state(
700    state_name: str,
701    pos_dim_tensors: List[torch.Tensor],
702    unflat_param_names: List[str],
703    unflat_param_shapes: Sequence[torch.Size],
704    handle: FlatParamHandle,
705) -> torch.Tensor:
706    """
707    Flattens the positive-dimension tensor optimizer state given by the values
708    ``pos_dim_tensors`` for the state ``state_name`` for a single flat parameter
709    from ``handle`` corresponding to the unflattened parameter names
710    ``unflat_param_names`` and unflattened parameter shapes
711    ``unflat_param_shapes``. This flattens each unflattened parameter's tensor
712    state into one tensor.
713
714    NOTE: We use zero tensors for any unflattened parameters without state
715    since some value is required to fill those entries. This assumes that the
716    zero tensor is mathematically equivalent to having no state, which is true
717    for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all
718    optimizers.
719
720    Args:
721        state_name (str): Optimizer state name.
722        pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor
723            optimizer state values for the unflattened parameters corresponding
724            to the single flat parameter.
725        unflat_param_names (List[str]): A :class:`list` of unflattened
726            parameter names corresponding to the single flat parameter.
727        unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes
728            corresponding to the single flat parameter.
729        handle (FlatParamHandle): The flat parameter's handle.
730
731    Returns:
732        torch.Tensor: A flat tensor containing the optimizer state
733        corresponding to ``state_name`` constructed by concatenating the
734        unflattened parameter tensor states in ``pos_dim_tensors`` (using zero
735        tensors for any unflattened parameters without the state).
736    """
737    flat_param = handle.flat_param
738    non_none_tensors = [t for t in pos_dim_tensors if t is not None]
739    # Check that all are tensors with the same dtype
740    dtypes = {t.dtype for t in non_none_tensors}
741    if len(dtypes) != 1:
742        raise ValueError(
743            "All unflattened parameters comprising a single flat "
744            "parameter must have positive-dimension tensor state with the "
745            f"same dtype but got dtypes {dtypes} for state {state_name} and "
746            f"unflattened parameter names {unflat_param_names}"
747        )
748    dtype = next(iter(dtypes))
749    # Check that each tensor state matches its parameter's shape
750    for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes):
751        if tensor is None and len(shape) == 0:
752            raise ValueError("Flattening a zero-dimension parameter is not supported")
753        elif tensor is not None and tensor.shape != shape:
754            raise ValueError(
755                "Tensor optimizer state does not have same shape as its "
756                f"parameter: {tensor.shape} {shape}"
757            )
758    # Flatten the tensor states: we do not need to add any right-hand-side
759    # padding since the flat optimizer state tensor is sharded via
760    # `_get_shard()`, which pads the shard as needed (just like for the flat
761    # parameter)
762    cpu_device = torch.device("cpu")
763    tensors_to_flatten = [
764        torch.flatten(state_value.to(cpu_device))
765        if state_value is not None
766        else torch.flatten(
767            torch.zeros(
768                size=shape,
769                dtype=dtype,
770                device=cpu_device,
771            )
772        )
773        for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes)
774    ]
775    flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
776    flat_param_shape = flat_param._unpadded_unsharded_size  # type: ignore[attr-defined]
777    assert flat_tensor.shape == flat_param_shape, (
778        f"tensor optim state: {flat_tensor.shape} "
779        f"flat parameter: {flat_param_shape}"
780    )
781    return flat_tensor
782
783
784def _flatten_zero_dim_tensor_optim_state(
785    state_name: str,
786    zero_dim_tensors: List[torch.Tensor],
787    unflat_param_names: List[str],
788) -> torch.Tensor:
789    """
790    Flattens the zero-dimension tensor optimizer state given by the values
791    ``zero_dim_tensors`` for the state ``state_name`` for a single flat
792    parameter corresponding to the unflattened parameter names
793    ``unflat_param_names`` by enforcing that all tensors are the same and using
794    that common value.
795
796    NOTE: The requirement that the tensors are the same across all unflattened
797    parameters comprising the flat parameter is needed to maintain the
798    invariant that FSDP performs the same computation as its non-sharded
799    equivalent. This means that none of the unflattened parameters can be
800    missing this state since imposing a value may differ from having no value.
801    For example, for Adam's "step", no value means maximum bias correction,
802    while having some positive value means less bias correction.
803
804    Args:
805        state_name (str): Optimizer state name.
806        zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state
807            for the unflattened parameters corresponding to the single
808            flat parameter.
809        unflat_param_names (List[str]): A :class:`list` of unflattened
810            parameter names corresponding to the single flat parameter.
811
812    Returns:
813        torch.Tensor: A zero-dimensional tensor giving the value of the state
814        ``state_name`` for all unflattened parameters corresponding to the
815        names ``unflat_param_names``.
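
    Example (illustrative; the parameter names and values are hypothetical)::

        _flatten_zero_dim_tensor_optim_state(
            "step",
            [torch.tensor(10.0), torch.tensor(10.0)],
            ["lin.weight", "lin.bias"],
        )  # returns tensor(10.) on CPU; differing values would raise ValueError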
816    """
817    non_none_tensors = [t for t in zero_dim_tensors if t is not None]
818    # Enforce that all have the same value and dtype
819    values_set = {t.item() if t is not None else None for t in zero_dim_tensors}
820    dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors}
821    if (
822        len(non_none_tensors) != len(zero_dim_tensors)
823        or len(values_set) != 1
824        or len(dtypes) != 1
825    ):
826        raise ValueError(
827            "All unflattened parameters comprising a single flat "
828            "parameter must have scalar state with the same value and dtype "
829            f"but got values {values_set} and dtypes {dtypes} for state "
830            f"{state_name} and unflattened parameter names "
831            f"{unflat_param_names}"
832        )
833    value = next(iter(values_set))
834    dtype = next(iter(dtypes))
835    return torch.tensor(value, dtype=dtype, device=torch.device("cpu"))
836
837
838def _flatten_non_tensor_optim_state(
839    state_name: str,
840    non_tensors: List[Any],
841    unflat_param_names: List[str],
842) -> Any:
843    """
844    Flattens the non-tensor optimizer state given by the values ``non_tensors``
845    for the state ``state_name`` for a single flat parameter corresponding
846    to the unflattened parameter names ``unflat_param_names`` by enforcing that
847    all values are the same and using that common value.
848
849    See the note in :func:`_flatten_zero_dim_tensor_optim_state`.
850
851    Args:
852        state_name (str): Optimizer state name.
853        non_tensors (List[Any]): Non-tensor optimizer state for the unflattened
854            parameters corresponding to the single flat parameter.
855        unflat_param_names (List[str]): A :class:`list` of unflattened
856            parameter names corresponding to the single flat parameter.
857
858    Returns:
859        Any: A non-tensor giving the value of the state ``state_name`` for all
860        unflattened parameters corresponding to the names
861        ``unflat_param_names``.
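
    Example (illustrative; the parameter names and values are hypothetical)::

        _flatten_non_tensor_optim_state("step", [10, 10], ["lin.weight", "lin.bias"])
        # returns 10; differing values would raise a ValueError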
862    """
863    non_none_non_tensors = [nt for nt in non_tensors if nt is not None]
864    # Enforce that all have the same value (same type already checked)
865    non_tensor_set = set(non_tensors)
866    if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1:
867        raise ValueError(
868            "All unflattened parameters comprising a single flat "
869            "parameter must have non-tensor state with the same value "
870            f"but got values {non_tensor_set} for state {state_name} and "
871            f"unflattened parameter names {unflat_param_names}"
872        )
873    non_tensor = next(iter(non_tensor_set))
874    return non_tensor
875
876
877def _rekey_sharded_optim_state_dict(
878    sharded_osd: Dict[str, Any],
879    model: nn.Module,
880    optim: torch.optim.Optimizer,
881    optim_input: Optional[
882        Union[
883            List[Dict[str, Any]],
884            Iterable[nn.Parameter],
885        ]
886    ],
887    using_optim_input: bool,
888    is_named_optimizer: bool = False,
889) -> Dict[str, Any]:
890    """
891    Rekeys the optimizer state dict from unflattened parameter names to flat
892    parameter IDs according to the calling rank's ``optim``, which may be
893    different across ranks. In particular, the unflattened parameter names are
894    represented as :class:`_OptimStateKey` s.
895    """
896    param_to_fqns = _get_param_to_fqns(model)
897    flat_param_to_fqn = _get_flat_param_to_fqn(model)
898    param_to_param_key: Dict[nn.Parameter, Union[int, str]] = cast(
899        Dict[nn.Parameter, Union[int, str]],
900        (
901            _get_param_to_param_id_from_optim_input(model, optim_input)
902            if using_optim_input
903            else _get_param_to_param_key(
904                optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
905            )
906        ),
907    )
908    # All parameter keys in `param_to_param_key` should be in
909    # `param_to_fqns` -- strict inequality follows when not all parameters are
910    # passed to the optimizer
911    assert len(param_to_param_key) <= len(param_to_fqns)
912
913    unflat_param_names_to_flat_param_key: Dict[
914        Tuple[str, ...], Union[int, str]
915    ] = {}  # for "state"
916    unflat_param_name_to_flat_param_key: Dict[
917        str, Union[int, str]
918    ] = {}  # for "param_groups"
919    for param, unflat_param_names in param_to_fqns.items():
920        if param not in param_to_param_key:
921            # This parameter was not passed to the optimizer
922            continue
923        flat_param_key = param_to_param_key[param]
924        unflat_param_names_to_flat_param_key[tuple(unflat_param_names)] = flat_param_key
925        for unflat_param_name in unflat_param_names:
926            unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key
927
928    sharded_osd_state = sharded_osd["state"]
929    rekeyed_osd_state: Dict[Union[str, int], Any] = {}
930    for key, param_state in sharded_osd_state.items():
931        if isinstance(key, str):
932            rekeyed_osd_state[key] = param_state
933            continue
934        flat_param_key = unflat_param_names_to_flat_param_key.get(
935            key.unflat_param_names, key.unflat_param_names
936        )
937        rekeyed_osd_state[flat_param_key] = param_state
938
939    # Only process param_groups if it exists in sharded_osd
940    if "param_groups" in sharded_osd:
941        rekeyed_osd_param_groups: List[Dict[str, Any]] = []
942        for unflat_param_group in sharded_osd["param_groups"]:
943            flat_param_group = copy.deepcopy(unflat_param_group)
944            flat_param_keys = sorted(
945                {
946                    unflat_param_name_to_flat_param_key[unflat_param_name]
947                    for unflat_param_name in unflat_param_group["params"]
948                }
949            )
950            flat_param_group["params"] = flat_param_keys
951            rekeyed_osd_param_groups.append(flat_param_group)
952        return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups}
953    else:
954        return {"state": rekeyed_osd_state}
955
956
957def _get_param_id_to_param_from_optim_input(
958    model: nn.Module,
959    optim_input: Optional[
960        Union[
961            List[Dict[str, Any]],
962            Iterable[nn.Parameter],
963        ]
964    ] = None,
965) -> Dict[int, nn.Parameter]:
966    """
967    Constructs a mapping from parameter IDs to parameters. This may be used
968    both for models with ``FlatParameter`` s and without.
969
970    NOTE: This method is only preserved for backward compatibility. The method
971    :meth:`_get_param_key_to_param` is the preferred code path that does not
972    rely on ``optim_input``.
973
974    NOTE: We critically assume that, whether the optimizer input is a list of
975    parameters or a list of parameter groups, :class:`torch.optim.Optimizer`
976    enumerates the parameter IDs in order. In other words, for a parameter list
977    input, the parameter IDs should be in that list order, and for a parameter
978    groups input, the parameter IDs should be in order within each parameter
979    group and in order across parameter groups.
980
981    Args:
982        model (nn.Module): Model whose parameters are passed into the
983            optimizer.
984        optim_input (Optional[Union[List[Dict[str, Any]],
985        Iterable[nn.Parameter]]]): Input passed into the optimizer
986            representing either a :class:`list` of parameter groups or an
987            iterable of parameters; if ``None``, then this method assumes the
988            input was ``model.parameters()``. (Default: ``None``)
989
990    Returns:
991        Dict[int, nn.Parameter]: Mapping from parameter IDs to parameters,
992        where the parameter ID is implicitly the parameter's index in the optimizer input.
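
    Example (illustrative; assumes ``model.parameters()`` was passed to the
    optimizer)::

        param_id_to_param = _get_param_id_to_param_from_optim_input(model)
        # param_id_to_param[0] is the first parameter from model.parameters()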
993    """
994    # Assume the standard case of passing `model.parameters()` to the optimizer
995    # if `optim_input` is not specified
996    if optim_input is None:
997        return dict(enumerate(model.parameters()))
998    try:
999        params = cast(List[nn.Parameter], list(optim_input))
1000    except TypeError as e:
1001        raise TypeError(
1002            "Optimizer input should be an iterable of Tensors or dicts, "
1003            f"but got {optim_input}"
1004        ) from e
1005    if len(params) == 0:
1006        raise ValueError("Optimizer input should not be empty")
1007
1008    # Check if the optimizer input represents tensors or parameter groups
1009    all_tensors = True
1010    all_dicts = True
1011    for param in params:
1012        all_tensors &= isinstance(param, torch.Tensor)
1013        all_dicts &= isinstance(param, dict)
1014    if not all_tensors and not all_dicts:
1015        raise TypeError("Optimizer input should be an iterable of Tensors or dicts")
1016    if all_tensors:
1017        return dict(enumerate(params))
1018    assert all_dicts
1019    param_id_to_param: List[nn.Parameter] = []
1020    for param_group in params:
1021        has_params_key = "params" in param_group  # type: ignore[operator]
1022        assert has_params_key, (
1023            'A parameter group should map "params" to a list of the '
1024            "parameters in the group"
1025        )
1026        # Implicitly map `flat_param_id` (current length of the list) to
1027        # `param`
1028        param_id_to_param.extend(param_group["params"])  # type: ignore[index]
1029    return dict(enumerate(param_id_to_param))
1030
1031
1032def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]:
1033    """
1034    Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes
1035    from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical"
1036    because ``FlatParameter`` s do not come from the original module but are
1037    registered only after FSDP has been applied. This function returns the FSDP-given
1038    name for the ``FlatParameter`` (usually ``module._flat_param``) as opposed to the
1039    canonical FQNs returned for ``FlatParameter`` s in ``_common_utils._get_param_to_fqns(...)``.
1040
1041    Consequently, this function will only return a non-empty mapping if FSDP was
1042    applied with ``use_orig_params=False`` as, otherwise, the original parameters
1043    are used within the module and there would be no ``FlatParameter`` s in the module.
1044
1045    """
1046
1047    def module_fn(module, prefix, tree_level, flat_param_to_fqn):
1048        for param_name, param in _named_parameters_with_duplicates(
1049            module, recurse=False
1050        ):
1051            if not isinstance(param, FlatParameter):
1052                continue
1053            fqn = clean_tensor_name(prefix + param_name)
1054            flat_param_to_fqn[param] = fqn
1055
1056    def return_fn(flat_param_to_fqn):
1057        return flat_param_to_fqn
1058
1059    flat_param_to_fqn_ret: Dict[FlatParameter, str] = {}
1060    return _apply_to_modules(
1061        model,
1062        module_fn,
1063        return_fn,
1064        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
1065        flat_param_to_fqn_ret,
1066    )
1067
1068
1069def _get_param_key_to_param(
1070    optim: torch.optim.Optimizer,
1071    model: Optional[nn.Module] = None,
1072    is_named_optimizer: bool = False,
1073    param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
1074    flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
1075) -> Dict[Union[int, str], nn.Parameter]:
1076    """
1077    Constructs a mapping from parameter keys to parameters. For the regular
1078    optimizers, the keys are parameter IDs. For NamedOptimizer, the keys
1079    are FQNs. This API may be used both for models with ``FlatParameter`` s and
1080    without.
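
    Example (illustrative; assumes a regular optimizer, which keys parameters
    by integer parameter ID)::

        param_key_to_param = _get_param_key_to_param(optim)
        # param_key_to_param[0] is optim.param_groups[0]["params"][0]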
1081    """
1082    clean_fqn_to_curr_fqn: Dict[str, str] = {}
1083    if is_named_optimizer:
1084        assert (
1085            param_to_fqns is not None and flat_param_to_fqn is not None
1086        ), "For a NamedOptimizer, `param_to_fqns` and `flat_param_to_fqn` must not be None."
1087        assert model is not None
1088        for key, _ in _named_parameters_with_duplicates(model):
1089            clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
1090
1091    param_key_to_param: Dict[Union[str, int], nn.Parameter] = {}
1092    pid = 0
1093    for param_group in optim.param_groups:
1094        if is_named_optimizer:
1095            for param in param_group["params"]:
1096                assert flat_param_to_fqn is not None
1097                if param in flat_param_to_fqn:
1098                    # FlatParameter case
1099                    key = flat_param_to_fqn[param]
1100                else:
1101                    assert param_to_fqns is not None
1102                    # use_orig_params case
1103                    assert len(param_to_fqns[param]) == 1
1104                    key = param_to_fqns[param][0]
1105                try:
1106                    key = clean_fqn_to_curr_fqn[key]
1107                except KeyError as e:
1108                    raise KeyError(
1109                        f"Can't find {key} from {list(clean_fqn_to_curr_fqn.keys())}."
1110                    ) from e
1111                param_key_to_param[key] = param
1112        else:
1113            for param in param_group["params"]:
1114                param_key_to_param[pid] = param
1115                pid += 1
1116
1117    return param_key_to_param
1118
1119
1120def _get_param_to_param_key(
1121    optim: torch.optim.Optimizer,
1122    model: Optional[nn.Module] = None,
1123    is_named_optimizer: bool = False,
1124    param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None,
1125    flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None,
1126) -> Dict[nn.Parameter, Union[int, str]]:
1127    """
1128    Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API
1129    only supports the case where `optim` is a regular optimizer, not NamedOptimizer.
1130    So the parameter keys will be parameter IDs.
1131    """
1132    param_id_to_param = _get_param_key_to_param(
1133        optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
1134    )
1135    return {param: param_id for param_id, param in param_id_to_param.items()}
1136
1137
1138def _get_param_to_param_id_from_optim_input(
1139    model: nn.Module,
1140    optim_input: Optional[
1141        Union[
1142            List[Dict[str, Any]],
1143            Iterable[nn.Parameter],
1144        ]
1145    ] = None,
1146) -> Dict[nn.Parameter, int]:
1147    """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`."""
1148    param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input)
1149    return {param: param_id for param_id, param in param_id_to_param.items()}
1150
1151
1152def _check_missing_keys_on_rank(
1153    r0_optim_state_keys: List[_OptimStateKey],
1154    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[str, int]],
1155    param_key_to_param: Dict[Union[str, int], nn.Parameter],
1156    group: Optional[dist.ProcessGroup],
1157) -> None:
1158    # Ensure that all ranks have at least the optimizer states needed by
1159    # rank 0's optimizer
1160    missing_keys: List[_OptimStateKey] = []
1161    for r0_optim_state_key in r0_optim_state_keys:
1162        if r0_optim_state_key not in optim_state_key_to_param_key:
1163            # A parameter from rank 0's optimizer does not exist for this
1164            # rank's optimizer
1165            missing_keys.append(r0_optim_state_key)
1166            continue
1167        param_key = optim_state_key_to_param_key[r0_optim_state_key]
1168        if isinstance(param_key, int):
1169            assert param_key >= 0 and param_key < len(
1170                param_key_to_param
1171            ), "Check the `param_key_to_param` construction"
1172    # We cannot use FSDPState.compute_device as this API is a global view.
1173    device = _get_pg_default_device(group)
1174    num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)
1175    dist.all_reduce(num_missing, group=group)
1176    if num_missing.item() > 0:
1177        obj_list = [None for _ in range(dist.get_world_size(group))]
1178        dist.all_gather_object(obj_list, missing_keys, group=group)
1179        error_msg = (
1180            "FSDP currently requires each rank to have at least the "
1181            "optimizer states needed by rank 0's optimizer but some ranks "
1182            "are missing some of those states"
1183        )
1184        for rank, keys in enumerate(obj_list):
1185            keys = cast(List[_OptimStateKey], keys)
1186            if len(keys) > 0:
1187                error_msg += (
1188                    f"\nRank {rank} is missing states for the parameters: "
1189                    f"{[key.unflat_param_names for key in keys]}"
1190                )
1191        raise RuntimeError(error_msg)
1192
1193
1194def _map_param_key_to_optim_keys(
1195    optim_state_dict: Dict[str, Any],
1196    group: Optional[dist.ProcessGroup],
1197    param_key_to_param: Dict[Union[int, str], nn.Parameter],
1198    param_to_fqns: Dict[nn.Parameter, List[str]],
1199    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
1200    merge_keys: bool = False,
1201) -> Tuple[List[_OptimStateKey], Dict[_OptimStateKey, Union[int, str]]]:
1202    """
1203    Construct the local mapping between ``_OptimStateKey`` s and parameter keys
1204    and collect all the ``_OptimStateKey`` s across ranks. If ``merge_keys`` is
1205    False, rank 0 must contain all the ``_OptimStateKey`` s; otherwise, an
1206    exception will be raised. Note that ``merge_keys`` should equal ``use_orig_params``.
1207    """
1208    rank = dist.get_rank(group)
1209    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]] = {}  # local
1210    all_optim_state_keys: List[_OptimStateKey] = []
1211
1212    for param_key, param in param_key_to_param.items():
1213        # Do not include parameters without state to avoid empty mappings
1214        # just like in normal `torch.optim.Optimizer.state_dict()`
1215        if param_key not in optim_state_dict["state"]:
1216            continue
1217        fqns = param_to_fqns[param]
1218        is_fsdp_managed = isinstance(param, FlatParameter)
1219        if is_fsdp_managed:
1220            assert fqns[0] in fqn_to_fsdp_param_info, (
1221                fqns[0],
1222                list(fqn_to_fsdp_param_info.keys()),
1223            )
1224        is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
1225        optim_state_key = _OptimStateKey(
1226            unflat_param_names=tuple(fqns),
1227            is_fsdp_managed=is_fsdp_managed,
1228        )
1229        if rank == 0 or merge_keys:
1230            all_optim_state_keys.append(optim_state_key)
1231        optim_state_key_to_param_key[optim_state_key] = param_key
1232
1233    if merge_keys:
1234        all_keys: List[List[_OptimStateKey]] = [
1235            [] for _ in range(dist.get_world_size(group))
1236        ]
1237        dist.all_gather_object(all_keys, all_optim_state_keys, group=group)
1238        merge_all_optim_state_keys = [
1239            key for local_keys in all_keys for key in local_keys
1240        ]
1241        all_optim_state_keys = sorted(set(merge_all_optim_state_keys))
1242    else:
1243        key_obj_list: List[Optional[List[_OptimStateKey]]] = (
1244            [all_optim_state_keys] if rank == 0 else [None]
1245        )
1246        dist.broadcast_object_list(key_obj_list, src=0, group=group)
1247        assert key_obj_list[0] is not None
1248        all_optim_state_keys = key_obj_list[0]
1249        _check_missing_keys_on_rank(
1250            all_optim_state_keys,
1251            optim_state_key_to_param_key,
1252            param_key_to_param,
1253            group,
1254        )
1255
1256    return all_optim_state_keys, optim_state_key_to_param_key
1257
1258
1259def _unflatten_param_groups(
1260    state_dict: Dict[str, Any],
1261    param_key_to_param: Dict[Union[int, str], nn.Parameter],
1262    param_to_fqns: Dict[nn.Parameter, List[str]],
1263) -> List[Dict[str, Any]]:
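    """
    Maps the flat parameter keys in ``state_dict["param_groups"]`` back to the
    corresponding unflattened parameter FQNs and returns the rewritten
    parameter groups.
    """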
1264    param_groups: List[Dict[str, Any]] = []
1265    for flat_param_group in state_dict["param_groups"]:
1266        unflat_param_group = copy.deepcopy(flat_param_group)
1267        param_group_params = [
1268            param_key_to_param[flat_param_key]
1269            for flat_param_key in flat_param_group["params"]
1270        ]
1271        nested_unflat_param_names = [
1272            param_to_fqns[param] for param in param_group_params
1273        ]
1274        unflat_param_group["params"] = [
1275            unflat_param_name
1276            for unflat_param_names in nested_unflat_param_names
1277            for unflat_param_name in unflat_param_names
1278        ]  # flatten the list of lists
1279        param_groups.append(unflat_param_group)
1280    return param_groups
1281
1282
1283def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool:
1284    """
1285    Returns whether the state_dict is from a NamedOptimizer.
1286    This function checks whether the keys in ``state_dict["state"]`` are strings
1287    (which usually are FQNs) or integers (which usually refer to param_ids
1288    from a vanilla torch.optim.Optimizer).
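
    Example (illustrative)::

        _is_named_optimizer({"state": {0: {"step": 1}}})             # False
        _is_named_optimizer({"state": {"lin.weight": {"step": 1}}})  # True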
1289    """
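    # A minimal hypothetical sketch of the two cases this distinguishes:
    #   {"state": {"lin.weight": {...}}}  -> True  (FQN string keys)
    #   {"state": {0: {...}}}             -> False (integer param IDs)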
1290    state = optim_state_dict.get("state", None)
1291    if not state:
1292        # If we cannot find a state, assume it is not NamedOptimizer as
1293        # NamedOptimizer has eager initialization.
1294        return False
1295    try:
1296        key = next(iter(state.keys()))
1297    except Exception as e:
1298        raise Exception(optim_state_dict) from e  # noqa: TRY002
1299    return isinstance(key, str)
1300
1301
1302@dataclass
1303class StateInfo:
1304    # The key of these dictionaries are the state name, e.g., `exp_avg`.
1305    tensors: Dict[str, _PosDimTensorInfo]
1306    scalar_tensors: Dict[str, torch.Tensor]
1307    non_tensors: Dict[str, Any]
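    # A hypothetical instance for an Adam-like state (sketch):
    #   StateInfo(
    #       tensors={"exp_avg": _PosDimTensorInfo(torch.Size([8, 4]), torch.float32)},
    #       scalar_tensors={"step": torch.tensor(10.0)},
    #       non_tensors={},
    #   )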
1308
1309
1310def _allgather_state_info(
1311    fsdp_state: _FSDPState,
1312    input_states: Dict[str, Any],
1313) -> List[Dict[str, StateInfo]]:
1314    """
1315    Given the ``input_states``, allgather StateInfo for each state. The function
1316    uses all_gather_object to gather StateInfo so no GPU tensors are sent.
1317    """
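    # The returned list is indexed by rank. A hypothetical sketch for 2 ranks:
    #   [
    #       {"lin.weight": StateInfo(
    #           tensors={"exp_avg": _PosDimTensorInfo(torch.Size([8]), torch.float32)},
    #           scalar_tensors={"step": torch.tensor(3.0)},
    #           non_tensors={},
    #       )},
    #       {...},  # rank 1's metadata for the same FQNs
    #   ]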
1318
1319    processed_state_dict: Dict[str, StateInfo] = {}
1320    gathered_state_info: List[Dict[str, StateInfo]] = [
1321        {} for _ in range(fsdp_state.world_size)
1322    ]
1323
1324    for fqn, optim_state in input_states.items():
1325        # Allgather the scalar tensor states, non-tensor states, and tensor metadata.
1326        processed_state = StateInfo({}, {}, {})
1327        for state_name, value in sorted_items(optim_state):
1328            if torch.is_tensor(value):
1329                if value.dim() == 0:
1330                    # Ensure that `step` is on CPU.
1331                    processed_state.scalar_tensors[state_name] = value.cpu()
1332                else:
1333                    processed_state.tensors[state_name] = _PosDimTensorInfo(
1334                        value.shape, value.dtype
1335                    )
1336            else:
1337                processed_state.non_tensors[state_name] = value
1338        processed_state_dict[fqn] = processed_state
1339    dist.all_gather_object(
1340        gathered_state_info,
1341        processed_state_dict,
1342        group=fsdp_state.process_group,
1343    )
1344    return gathered_state_info
1345
1346
1347def _convert_all_state_info(
1348    fsdp_param_info: FSDPParamInfo,
1349    gathered_state_info: List[Dict[str, StateInfo]],
1350    input_states: Dict[str, Any],
1351    output_states: Dict[str, Dict[str, Any]],
1352) -> Tuple[Optional[torch.dtype], Dict[str, List[Optional[torch.Tensor]]]]:
1353    """
1354    Given the ``gathered_state_info`` and ``input_states``, the API converts
1355    the StateInfo back into the original state if the state is not a non-scalar
1356    tensor. For a multi-dimensional tensor, the local state will be stored in
1357    ``state_buffers`` in the correct order for the later allgather.
1358    """
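    # ``state_buffers`` maps each positive-dimension state name to a list with
    # one slot per parameter in the FlatParameter, e.g. (hypothetical sketch):
    #   {"exp_avg": [<local shard of param 0's exp_avg>, None, ...],
    #    "exp_avg_sq": [...]}
    # where a ``None`` slot means this rank holds no local state for that
    # parameter.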
1359
1360    state_buffers: Dict[str, List[Optional[torch.Tensor]]] = {}
1361
1362    for fqn, gathered_state in output_states.items():
1363        state_info = [s[fqn] for s in gathered_state_info]
1364        all_tensor_states = sorted(
1365            {n for state in state_info for n in state.tensors.keys()}
1366        )
1367        empty_ranks: Set[int] = set()
1368        dtype: Optional[torch.dtype] = None
1369        # First check all the non-scalar states and get the information of
1370        # states on each rank.
1371        for state_name in all_tensor_states:
1372            numels = []
1373            _empty_ranks: Set[int] = set()
1374            for rank, object_state in enumerate(state_info):
1375                numels.append(0)
1376                info = object_state.tensors.get(state_name, None)
1377                if info is not None:
1378                    numels[-1] = info.shape.numel()
1379                    if not dtype:
1380                        dtype = info.dtype
1381                    else:
1382                        assert dtype == info.dtype
1383                if numels[-1] == 0:
1384                    _empty_ranks.add(rank)
1385
1386            assert not empty_ranks or empty_ranks == _empty_ranks
1387            empty_ranks = _empty_ranks
1388            if state_name not in state_buffers:
1389                state_buffers[state_name] = [
1390                    None for _ in fsdp_param_info.param_indices
1391                ]
1392            local_state = input_states[fqn].get(state_name, None)
1393            # N.B. We need to move the state to compute_device. The reason is
1394            # not yet clear and we need to figure out why the state may be on a
1395            # different device.
1396            if local_state is not None:
1397                local_state = local_state.to(fsdp_param_info.state.compute_device)
1398            state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state
1399
1400        # Restore the scalar and non-tensor states. If the corresponding
1401        # non-scalar states do not exist on a rank, we also skip the scalar and
1402        # non-tensor states on that rank.
1403        for rank, object_state in enumerate(state_info):
1404            if rank in empty_ranks:
1405                continue
1406            for name, non_tensor_value in object_state.non_tensors.items():
1407                curr_non_tensor_value = gathered_state.get(name, None)
1408                assert (
1409                    curr_non_tensor_value is None
1410                    or curr_non_tensor_value == non_tensor_value
1411                ), (
1412                    f"Rank {rank} has different values for {name}: {non_tensor_value}."
1413                    + f" Other ranks: {curr_non_tensor_value}"
1414                )
1415                gathered_state[name] = non_tensor_value
1416
1417            for name, scalar_tensor_value in object_state.scalar_tensors.items():
1418                curr_scalar_tensor_value = gathered_state.get(name, None)
1419                assert curr_scalar_tensor_value is None or torch.equal(
1420                    scalar_tensor_value, curr_scalar_tensor_value
1421                ), (
1422                    f"Rank {rank} has different values for {name}: {scalar_tensor_value}."
1423                    + f" Other ranks: {curr_scalar_tensor_value}"
1424                )
1425                gathered_state[name] = scalar_tensor_value
1426
1427    return dtype, state_buffers  # type: ignore[possibly-undefined]
1428
1429
1430def _unflatten_orig_param_states(
1431    fsdp_param_info: FSDPParamInfo,
1432    output_states: Dict[str, Dict[str, Any]],
1433    state_name: str,
1434    shard_state: bool,
1435    to_save: bool,
1436    cpu_offload: bool,
1437) -> None:
1438    """
1439    Given an output state dict, ``output_states``, whose keys are FQNs of the
1440    original parameters (not FlatParameters nor parameter IDs) and whose values
1441    are gathered states, unflatten the states to the original dimensions.
1442
1443    This function performs the unflattening process in-place.
1444    """
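    # A hypothetical sketch: if ``flat_param._shapes[param_idx]`` is
    # torch.Size([16, 32]), then the gathered state tensor for that parameter is
    # reshaped to (16, 32) and, when ``shard_state`` is True, re-chunked into a
    # ShardedTensor or DTensor via the _ext_chunk_* helpers below.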
1445    if not to_save:
1446        return
1447    flat_param = fsdp_param_info.handle.flat_param
1448    fsdp_state = fsdp_param_info.state
1449    for fqn, gathered_state in output_states.items():
1450        value = gathered_state[state_name]
1451        param_idx = fsdp_param_info.param_indices[fqn]
1452
1453        # TODO: This solution is not general and only applies to the PTD TP solution.
1454        if isinstance(value, DTensor):
1455            placement = value.placements[0]
1456            # If the gathered state is a DTensor and its TP placement is not Replicate(), we need
1457            # to gather the tensor on its TP dimension before chunking it into a DTensor again.
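            # For example (hypothetical numbers): if flat_param._shapes[param_idx]
            # is torch.Size([4, 32]), placement_dim is 0, and the TP mesh size is
            # 4, then reshape_size below becomes torch.Size([16, 32]).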
1458            if placement != Replicate():
1459                placement_dim = placement.dim  # type: ignore[attr-defined]
1460                value_local = value.redistribute(placements=(Replicate(),))
1461                reshape_size = list(flat_param._shapes[param_idx])
1462                reshape_size[placement_dim] *= value.device_mesh.size(0)
1463                reshape_size = torch.Size(reshape_size)
1464                value = value.reshape(reshape_size)
1465            # If the gathered state is a replicated DTensor, we directly reshape it.
1466            else:
1467                value = value.reshape(flat_param._shapes[param_idx])
1468        else:
1469            # If the gathered state is a regular tensor, we directly reshape it into the unflattened state.
1470            value = value.reshape(flat_param._shapes[param_idx])
1471
1472        if shard_state:
1473            osd_config = fsdp_state._optim_state_dict_config
1474            if getattr(osd_config, "_use_dtensor", False):
1475                assert fsdp_state._device_mesh is not None
1476                value = _ext_chunk_dtensor(
1477                    value,
1478                    fsdp_state.rank,
1479                    fsdp_state._device_mesh,
1480                    fsdp_state._fsdp_extension,
1481                )
1482            else:
1483                assert fsdp_state.process_group is not None
1484                value = _ext_chunk_tensor(
1485                    value,
1486                    fsdp_state.rank,
1487                    fsdp_state.world_size,
1488                    fsdp_state._device_handle.device_count(),
1489                    fsdp_state.process_group,
1490                    fsdp_state._fsdp_extension,
1491                )
1492        elif not cpu_offload:
1493            with SimpleProfiler.profile("clone"):
1494                value = value.detach().clone()
1495
1496        if cpu_offload:
1497            with SimpleProfiler.profile(SimpleProfiler.Type.D2H):
1498                value = value.cpu()
1499        gathered_state[state_name] = value
1500
1501
1502def _allgather_orig_param_states(
1503    fsdp_param_info: FSDPParamInfo,
1504    gathered_state_info: List[Dict[str, StateInfo]],
1505    input_states: Dict[str, Any],
1506    shard_state: bool,
1507    to_save: bool,
1508    cpu_offload: bool,
1509) -> Dict[str, Dict[str, Any]]:
1510    """
1511    Given the ``gathered_state_info`` and ``input_states``, the API allgathers
1512    all tensor states and restores the non-tensor states from ``gathered_state_info``.
1513    """
1514    fsdp_state = fsdp_param_info.state
1515    if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL:
1516        logger.info(
1517            "Memory Summary before calling to _allgather_orig_param_states %s",
1518            fsdp_state._device_handle.memory_summary(),
1519        )
1520
1521    output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states.keys()}
1522
1523    dtype, state_buffers = _convert_all_state_info(
1524        fsdp_param_info, gathered_state_info, input_states, output_states
1525    )
1526
1527    if len(state_buffers) == 0:
1528        return output_states
1529
1530    has_state_params: List[bool] = [
1531        True if fqn in output_states else False
1532        for fqn, idx in fsdp_param_info.param_indices.items()
1533    ]
1534
1535    # Loop through ``state_buffers`` and construct the flattened, concatenated,
1536    # sharded states. The constructed state will have the same size as
1537    # flat_param (also sharded).
1538    # Then we perform an all_gather_into_tensor to get the full flat_param state.
1539    # The full flat_param state is the concatenation of multiple states in the
1540    # order of flat_param._fqns.
1541    # The final step is to split the flat_param state into the original param
1542    # states and return the result.
1543    flat_param = fsdp_param_info.handle.flat_param
1544    empty_func = functools.partial(
1545        torch.empty, dtype=dtype, device=fsdp_state.compute_device
1546    )
1547    gathered_tensor = empty_func(flat_param._padded_unsharded_size)
1548    # Synchronizing can be slow, but it makes debugging easier.
1549    fsdp_state._device_handle.synchronize()
1550    for state_name, buffers in state_buffers.items():
1551        local_buffers: List[torch.Tensor] = []
1552        begin = fsdp_state.rank * flat_param._sharded_size.numel()
1553        # End is inclusive.
1554        end = begin + flat_param._sharded_size.numel() - 1
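        # A hypothetical sketch of the numbers: with rank == 2 and a sharded
        # numel of 8, begin == 16 and end == 23 (both inclusive).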
1555        # param_idx corresponds to the parameter index in the FlatParameter.
1556        mem_offset, param_idx = 0, 0
1557        for numel, is_padding in zip(
1558            flat_param._numels_with_padding, flat_param._is_padding_mask
1559        ):
1560            frozen_and_no_state = not is_padding and (
1561                not fsdp_param_info.param_requires_grad[param_idx]
1562                and not has_state_params[param_idx]
1563            )
1564
1565            if is_padding or frozen_and_no_state:
1566                # This memory range is padding, or the param is frozen and does
1567                # not require gradient. For the latter case, we treat it as
1568                # padding and add empty values to the local_buffers.
1569
1570                padding_begin, padding_end = mem_offset, mem_offset + numel - 1
1571                if padding_begin <= begin <= padding_end:
1572                    # The range is an align padding before the first parameter in
1573                    # the shard. The shard includes parts of this align padding.
1574                    padding_len = (
1575                        padding_end - begin + 1
1576                        if end >= padding_end
1577                        else end - begin + 1
1578                    )
1579                elif padding_begin <= end <= padding_end:
1580                    # The range is an align padding after the last parameter in
1581                    # the shard. The shard includes parts of this align padding.
1582                    padding_len = (
1583                        end - padding_begin + 1
1584                        if begin <= padding_begin
1585                        else end - begin + 1
1586                    )
1587                elif begin < padding_begin <= padding_end < end:
1588                    # The range is an align padding that is completely in the
1589                    # shard.
1590                    padding_len = numel
1591                else:
1592                    padding_len = 0
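                # Worked example of the branches above (hypothetical numbers):
                # with begin == 16, end == 23, and a padding spanning [20, 25],
                # the second branch applies and
                # padding_len == end - padding_begin + 1 == 4.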
1593                if padding_len:
1594                    local_buffers.append(empty_func(padding_len))
1595
1596            if not is_padding:
1597                # This memory range is a parameter in FlatParameter, so there
1598                # should be a corresponding state in the optimizer unless the
1599                # parameter is frozen, which we treat as padding above.
1600
1601                # We need to check if this rank owns the buffer. If the buffer
1602                # is None, then either:
1603                # 1.) the rank does not own any part of the original parameter,
1604                #     so there is no corresponding optimizer state on this rank
1605                #     either; or
1606                # 2.) the parameter is frozen AND there is no optimizer state
1607                #     for it. A frozen parameter can still have optimizer state
1608                #     if it was not frozen in previous steps.
1609                if buffers[param_idx] is not None:
1610                    local_buffers.append(cast(torch.Tensor, buffers[param_idx]))
1611                param_idx += 1
1612
1613            mem_offset += numel
1614
1615        shard_numel_padded = flat_param._sharded_size.numel() - (
1616            sum(t.numel() for t in local_buffers)
1617        )
1618
1619        assert flat_param._shard_numel_padded == shard_numel_padded, (
1620            "Manually calculated _shard_numel_padded is incorrect. "
1621            f"_shard_numel_padded={flat_param._shard_numel_padded}, "
1622            f"shard_numel_padded={shard_numel_padded}, "
1623            f"_sharded_size.numel={flat_param._sharded_size.numel()}, "
1624            f"_numels_with_padding={flat_param._numels_with_padding}, "
1625            f"begin={begin}, end={end},"
1626        )
1627        if shard_numel_padded > 0:
1628            # Add right-handed padding.
1629            local_buffers.append(empty_func(shard_numel_padded))
1630        local_shard = torch.cat(local_buffers)
1631        assert local_shard.numel() * fsdp_state.world_size == gathered_tensor.numel(), (
1632            "The size of the local shard times the world size should equal the "
1633            "gathered tensor size. The inconsistency may come from a bug in "
1634            "FlatParameter's metadata or in the optimizer state dict "
1635            "reconstruction logic."
1636        )
1637        fsdp_state._device_handle.synchronize()
1638        with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
1639            dist.all_gather_into_tensor(
1640                gathered_tensor, local_shard, group=fsdp_state.process_group
1641            )
1642            # Synchronizing can be slow, but it makes debugging easier.
1643            fsdp_state._device_handle.synchronize()
1644
1645        unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()]
1646        flat_param_handle = fsdp_param_info.handle
1647        orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor)
1648        assert len(orig_states) == len(fsdp_param_info.param_indices), (
1649            "The number of parameters from FlatParameter is not consistent with "
1650            "the number of states used by the optimizer state dict "
1651            "reconstruction logic."
1652        )
1653        for fqn, idx in fsdp_param_info.param_indices.items():
1654            if fsdp_param_info.param_requires_grad[idx] or fqn in output_states:
1655                output_states[fqn][state_name] = orig_states[idx]
1656
1657        _unflatten_orig_param_states(
1658            fsdp_param_info,
1659            output_states,
1660            state_name,
1661            shard_state,
1662            to_save,
1663            cpu_offload,
1664        )
1665
1666    del gathered_tensor
1667    return output_states
1668
1669
1670def _gather_all_orig_param_state(
1671    fsdp_param_info: FSDPParamInfo,
1672    input_states: Dict[str, Any],
1673    shard_state: bool,
1674    to_save: bool,
1675    cpu_offload: bool,
1676) -> Dict[str, Any]:
1677    """
1678    Given an optimizer state dict, ``input_states``, whose keys are FQNs of the
1679    original parameters (not FlatParameters nor parameter IDs), gather all the
1680    states and unflatten them to the original dimensions. Note that all the
1681    params referred to by ``input_states`` must be managed by FSDP.
1682    """
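    # A hypothetical sketch of the shapes involved:
    #   input_states = {"lin.weight": {"exp_avg": <local 1-D shard>,
    #                                  "step": torch.tensor(3.0)}}
    # returns a dict mapping the same FQN to states restored to the original
    # parameter dimensions (gathered across ranks, or re-sharded when
    # ``shard_state`` is True).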
1683    fsdp_state = fsdp_param_info.state
1684    if (
1685        fsdp_state.world_size == 1
1686        or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD
1687    ):
1688        return input_states if to_save else {}
1689
1690    with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING):
1691        with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ):
1692            gathered_state_info = _allgather_state_info(fsdp_state, input_states)
1693        output_states = _allgather_orig_param_states(
1694            fsdp_param_info,
1695            gathered_state_info,
1696            input_states,
1697            shard_state,
1698            to_save,
1699            cpu_offload,
1700        )
1701    if to_save:
1702        for key, idx in fsdp_param_info.param_indices.items():
1703            if key in output_states:
1704                continue
1705            if not fsdp_param_info.param_requires_grad[idx]:
1706                continue
1707
1708            raise RuntimeError(
1709                f"{key} is not in the output state. "
1710                "The FSDPParamInfo has the param keys "
1711                f"{sorted(fsdp_param_info.param_indices.keys())} while "
1712                "the output_states has the param keys "
1713                f"{sorted(output_states.keys())}."
1714            )
1715        return output_states
1716    else:
1717        return {}
1718
1719
1720def _convert_state_with_orig_params(
1721    all_optim_state_keys: List[_OptimStateKey],
1722    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
1723    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
1724    optim_state_dict: Dict[Union[str, int], Any],
1725    to_save: bool,
1726    shard_state: bool,
1727    cpu_offload: bool = True,
1728) -> Dict[str, Any]:
1729    fsdp_osd_state: Dict[str, Any] = {}
1730    # This variable is used to deduplicate FSDPParamInfo since one FSDPParamInfo
1731    # usually corresponds to multiple parameters. We cannot use FSDPParamInfo
1732    # as the key because FSDPParamInfo is not hashable. As a result, we fall back
1733    # to `id(FSDPParamInfo)`, which is an integer.
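    # E.g. (hypothetical sketch), ``all_states`` may end up looking like
    #   {140234235929616: {"layer1.weight": {...}, "layer1.bias": {...}}}
    # where the integer key is ``id(fsdp_param_info)`` of the FSDP instance that
    # owns both parameters.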
1734    all_states: Dict[int, Dict[str, Any]] = {}
1735    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
1736    # across ranks
1737    for optim_state_key in all_optim_state_keys:
1738        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
1739            optim_state_key, None
1740        )
1741
1742        if param_key is None and not optim_state_key.is_fsdp_managed:
1743            continue
1744
1745        if optim_state_key.is_fsdp_managed:
1746            fqn = optim_state_key.unflat_param_names[0]
1747            fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None)
1748            if fsdp_param_info is None:
1749                # This can happen if not all FSDP instances have all the
1750                # parameters, e.g., with FSDP combined with some MPMD-style
1751                # parallelism.
1752
1753                # TODO: it is unclear if we need to do the same check with
1754                # non-FSDP managed keys.
1755                continue
1756            state = {} if param_key is None else optim_state_dict[param_key]
1757            if id(fsdp_param_info) not in all_states:
1758                all_states[id(fsdp_param_info)] = {}
1759            all_states[id(fsdp_param_info)][fqn] = state
1760
1761        elif to_save:
1762            assert len(optim_state_key.unflat_param_names) == 1
1763            unflat_param_name = optim_state_key.unflat_param_names[0]
1764            with SimpleProfiler.profile("none_fsdp_managed_copy"):
1765                param_key = cast(Union[str, int], param_key)
1766                fsdp_osd_state[unflat_param_name] = copy.copy(
1767                    optim_state_dict[param_key]
1768                )
1769                if cpu_offload:
1770                    for state_name, value in sorted_items(
1771                        fsdp_osd_state[unflat_param_name]
1772                    ):
1773                        if not torch.is_tensor(value):
1774                            continue
1775                        fsdp_osd_state[unflat_param_name][state_name] = value.cpu()
1776
1777    # Instead of gathering the state of each parameter individually, we perform
1778    # the gathering all at once to speed up the process.
1779    for _all_states in all_states.values():
1780        fqn = next(iter(_all_states.keys()))
1781        fsdp_param_info = fqn_to_fsdp_param_info[fqn]
1782        assert len(fsdp_param_info.param_requires_grad) > 0, (
1783            "With use_orig_params, FSDPParamInfo should have requires_grad "
1784            "information. However, the length is zero."
1785        )
1786        for key, idx in fsdp_param_info.param_indices.items():
1787            if key in _all_states:
1788                continue
1789            if not fsdp_param_info.param_requires_grad[idx]:
1790                continue
1791            raise RuntimeError(
1792                f"{key} is not in the optimizer state. "
1793                "The FSDPParamInfo has the param keys "
1794                f"{sorted(fsdp_param_info.param_indices.keys())} while "
1795                "the optimizer has the param keys "
1796                f"{sorted(_all_states.keys())}."
1797            )
1798        fsdp_osd_state.update(
1799            _gather_all_orig_param_state(
1800                fsdp_param_info,
1801                _all_states,
1802                shard_state,
1803                to_save,
1804                cpu_offload,
1805            )
1806        )
1807
1808    return fsdp_osd_state
1809
1810
1811def _convert_state_with_flat_params(
1812    all_optim_state_keys: List[_OptimStateKey],
1813    optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
1814    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
1815    optim_state_dict: Dict[Union[str, int], Any],
1816    to_save: bool,
1817    shard_state: bool,
1818    cpu_offload: bool = True,
1819) -> Dict[str, Any]:
1820    fsdp_osd_state: Dict[str, Any] = {}
1821    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
1822    # across ranks
1823    for optim_state_key in all_optim_state_keys:
1824        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
1825            optim_state_key, None
1826        )
1827
1828        assert param_key is not None, (
1829            "If use_orig_params is False, we must be able to find the "
1830            f"corresponding param id. {optim_state_key} {param_key}"
1831        )
1832
1833        if optim_state_key.is_fsdp_managed:
1834            # If there are multiple unflat_param_names (not use_orig_params),
1835            # they share the same FSDPParamInfo. So the first unflat_param_name
1836            # is sufficient to fetch the FSDPParamInfo.
1837            fqn = optim_state_key.unflat_param_names[0]
1838            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
1839            unflat_state = _unflatten_optim_state(
1840                fsdp_param_info,
1841                optim_state_dict[param_key],
1842                to_save,
1843                shard_state,
1844                cpu_offload,
1845            )
1846            if to_save:
1847                assert len(unflat_state) == len(optim_state_key.unflat_param_names)
1848                for unflat_param_name, unflat_param_state in zip(
1849                    optim_state_key.unflat_param_names,
1850                    unflat_state,
1851                ):
1852                    fsdp_osd_state[unflat_param_name] = unflat_param_state
1853        elif to_save:
1854            assert len(optim_state_key.unflat_param_names) == 1
1855            unflat_param_name = optim_state_key.unflat_param_names[0]
1856            fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key])
1857            if cpu_offload:
1858                for state_name, value in sorted_items(
1859                    fsdp_osd_state[unflat_param_name]
1860                ):
1861                    if not torch.is_tensor(value):
1862                        continue
1863                    fsdp_osd_state[unflat_param_name][state_name] = value.cpu()
1864
1865    return fsdp_osd_state
1866
1867
1868@torch.no_grad()
1869def _optim_state_dict(
1870    model: nn.Module,
1871    optim: torch.optim.Optimizer,
1872    optim_state_dict: Dict[str, Any],
1873    optim_input: Optional[
1874        Union[
1875            List[Dict[str, Any]],
1876            Iterable[nn.Parameter],
1877        ]
1878    ],
1879    rank0_only: bool,
1880    shard_state: bool,
1881    group: Optional[dist.ProcessGroup],
1882    using_optim_input: bool,
1883    use_orig_params: bool = False,
1884    cpu_offload: bool = True,
1885) -> Dict[str, Any]:
1886    """
1887    Consolidates the optimizer state and returns it as a :class:`dict`
1888    following the convention of :meth:`torch.optim.Optimizer.state_dict`,
1889    i.e. with keys ``"state"`` and ``"param_groups"``.
1890    The flat parameters in ``FSDP`` modules contained in ``model`` are mapped
1891    back to their unflattened parameters.
1892
1893    Parameter keys are not well-defined. For a regular optimizer, the optimizer
1894    state_dict contains a mapping from parameter IDs to parameter states.
1895    Parameter IDs are determined by the order of parameters in ``optim.param_groups``
1896    across all the groups. This API also allows the user to pass ``optim_input`` to
1897    specify the mapping between parameters and parameter IDs. Using ``optim_input``
1898    is being deprecated.
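    For example (a hypothetical sketch), if the optimizer was constructed with
    two parameter groups, the parameters in the first group receive IDs
    ``0..k-1`` in their iteration order and the second group's parameters
    continue from ``k``.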
1899
1900    If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not
1901    contain parameter IDs mapping but a mapping from parameter FQNs to parameter
1902    states. This API finds the mapping from FQNs to parameters if the optimizer
1903    is a ``NamedOptimizer``.
1904
1905    If ``use_orig_params`` is True, each rank will have all FSDP-managed
1906    parameters but some of these parameters may be empty due to the sharding.
1907    For a regular optim.Optimizer, states for those empty parameters will
1908    not be initialized. So, when aggregating the FQNs across ranks, no assert
1909    will be raised on a rank even if it does not have all the states -- it is
1910    valid and FSDP knows how to aggregate them. However, FSDP has to ignore
1911    handling those parameters that are not managed by FSDP and do not exist on
1912    the local rank -- those are managed by other parallelisms and FSDP does not
1913    know how to handle/aggregate them.
1914
1915    Args:
1916        model (nn.Module): Root module (which may or may not be a
1917            :class:`FullyShardedDataParallel` instance) whose parameters
1918            were passed into the optimizer ``optim``.
1919        optim (torch.optim.Optimizer): Optimizer for ``model`` 's
1920            parameters.
1921        rank0_only (bool): If ``True``, saves the populated :class:`dict`
1922            only on rank 0; if ``False``, saves it on all ranks. (Default:
1923            ``True``)
1924        shard_state (bool): If ``True``, shard and distribute all
1925            non-zero-dimension states.
1926
1927    Returns:
1928        Dict[str, Any]: A :class:`dict` containing the optimizer state for
1929        ``model`` 's original unflattened parameters and including keys
1930        "state" and "param_groups" following the convention of
1931        :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
1932        then nonzero ranks return an empty :class:`dict`.
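
    Example::

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> # A minimal sketch: in practice this helper is reached through the
        >>> # public ``FullyShardedDataParallel.optim_state_dict`` wrapper
        >>> # (assuming ``model`` is FSDP-wrapped and ``optim`` optimizes its
        >>> # parameters) rather than being called directly.
        >>> osd = FullyShardedDataParallel.optim_state_dict(model, optim)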
1933    """
1934    SimpleProfiler.reset()
1935    cm = ExitStack()
1936    cm.enter_context(SimpleProfiler.profile(SimpleProfiler.Type.ALL))
1937    _reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model))
1938    to_save = not rank0_only or dist.get_rank(group) == 0 or shard_state
1939
1940    with SimpleProfiler.profile("preprocessing"):
1941        param_to_fqns = _get_param_to_fqns(model)
1942        flat_param_to_fqn = _get_flat_param_to_fqn(model)
1943        is_named_optimizer = _is_named_optimizer(optim_state_dict)
1944
1945        param_key_to_param = cast(
1946            Dict[Union[int, str], nn.Parameter],
1947            (
1948                _get_param_id_to_param_from_optim_input(model, optim_input)
1949                if using_optim_input
1950                else _get_param_key_to_param(
1951                    optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn
1952                )
1953            ),
1954        )
1955        fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
1956
1957    with SimpleProfiler.profile("preprocessing_with_comm"):
1958        (
1959            all_optim_state_keys,
1960            optim_state_key_to_param_key,
1961        ) = _map_param_key_to_optim_keys(
1962            optim_state_dict,
1963            group,
1964            param_key_to_param,
1965            param_to_fqns,
1966            fqn_to_fsdp_param_info,
1967            merge_keys=use_orig_params,
1968        )
1969
1970    with SimpleProfiler.profile("state_converting"):
1971        convert_fn = (
1972            _convert_state_with_orig_params
1973            if use_orig_params
1974            else _convert_state_with_flat_params
1975        )
1976        fsdp_osd_state = convert_fn(
1977            all_optim_state_keys,
1978            optim_state_key_to_param_key,
1979            fqn_to_fsdp_param_info,
1980            optim_state_dict["state"],
1981            to_save,
1982            shard_state,
1983            cpu_offload,
1984        )
1985
1986    # At this point, communication is complete and ranks can return early if nothing
1987    # will be saved on that rank.
1988    if not to_save:
1989        return {}
1990
1991    fsdp_osd: Dict[str, Any] = {"state": fsdp_osd_state}
1992
1993    flat_param_fqns = set(flat_param_to_fqn.values())
1994    for key, value in optim_state_dict["state"].items():
1995        if key in fsdp_osd_state:
1996            continue
1997        if key in flat_param_fqns:
1998            continue
1999        if key in param_key_to_param:
2000            continue
2001        # This key is not recognized by FSDP. It may be a user-defined state
2002        # or some parameters state that FSDP is unable to map from
2003        # ``optim.param_groups``.
2004        warnings.warn(
2005            f"Found an optim state, {key}, that FSDP cannot process. FSDP "
2006            "will directly copy everything to the returned state_dict. In "
2007            "most cases, this is a user-defined state that is not "
2008            "associated with any particular parameter. Another possible "
2009            "case is that this state is managed by TorchRec. Otherwise, there "
2010            "may be a mismatched assumption about optim_state_dict in this mode."
2011        )
2012        fsdp_osd_state[key] = value
2013
2014    if "param_groups" in optim_state_dict:
2015        fsdp_osd["param_groups"] = _unflatten_param_groups(
2016            optim_state_dict, param_key_to_param, param_to_fqns
2017        )
2018
2019    cm.close()
2020    SimpleProfiler.dump_and_reset("FSDP _optim_state_dict() profiling: ")
2021
2022    return fsdp_osd
2023
2024
2025def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
2026    """
2027    Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
2028    if the param is managed by FSDP. Shared parameters, or original parameters that
2029    are shared across multiple nn.Modules, are required to belong to one and only
2030    one FSDP instance and thus correspond to one ``FlatParameter``. Within the one
2031    ``FlatParameter``, ``FlatParameter._fqns`` only stores the first FQN of a shared
2032    parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters.
2033    """
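    # A hypothetical sketch of the result: if "layer1.weight" and "layer1.bias"
    # are flattened into the same FlatParameter, both FQNs map to the same
    # FSDPParamInfo, with param_indices like
    #   {"layer1.weight": 0, "layer1.bias": 1}.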
2034
2035    def module_fn(module, prefix, tree_level, fqn_to_param_info):
2036        fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
2037        if fsdp_state is None:
2038            return
2039        _lazy_init(fsdp_state, module)
2040        handle = _module_handle(fsdp_state, module)
2041        if not handle:
2042            return
2043        flat_param = handle.flat_param
2044        fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, [])
2045        # NOTE: `idx` indexes into the data structures *without* padding
2046        # elements
2047        for idx, local_fqn in enumerate(flat_param._fqns):
2048            fqn = clean_tensor_name(prefix + local_fqn)
2049            if fqn in fqn_to_param_info:
2050                assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn
2051            fqn_to_param_info[fqn] = fsdp_param_info
2052            fsdp_param_info.param_indices[fqn] = idx
2053            if flat_param._params is not None:
2054                fsdp_param_info.param_requires_grad.append(
2055                    flat_param._params[idx].requires_grad
2056                )
2057
2058    def return_fn(fqn_to_param_info):
2059        return fqn_to_param_info
2060
2061    fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
2062    # FlatParameter._fqns stores the local fqn, starting from the root of the
2063    # FSDP. Using _apply_to_modules() with model (may not be the FSDP root
2064    # module) allows us to construct the global fqn.
2065    return _apply_to_modules(
2066        model,
2067        module_fn,
2068        return_fn,
2069        [fqn for fqn, _ in _named_parameters_with_duplicates(model)],
2070        fqn_to_param_info,
2071    )
2072
2073
2074@no_type_check
2075def _set_optim_use_dtensor(
2076    fsdp_state: _FSDPState,
2077    state_dict_settings: StateDictSettings,
2078) -> None:
2079    # If device_mesh is passed in when initializing FSDP, we automatically set the
2080    # _use_dtensor flag to True for ShardedOptimStateDictConfig() if state_dict_type
2081    # is set to SHARDED_STATE_DICT.
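    # For instance (hypothetical sketch), settings like
    #   StateDictSettings(
    #       state_dict_type=StateDictType.SHARDED_STATE_DICT,
    #       state_dict_config=ShardedStateDictConfig(),
    #       optim_state_dict_config=ShardedOptimStateDictConfig(),
    #   )
    # get ``_use_dtensor`` flipped to True below when the FSDP state carries a
    # ``_device_mesh``.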
2082    if getattr(fsdp_state, "_device_mesh", None):
2083        state_dict_type = state_dict_settings.state_dict_type
2084        if state_dict_type == StateDictType.LOCAL_STATE_DICT:
2085            raise RuntimeError(
2086                "Found state_dict_type LOCAL_STATE_DICT.",
2087                "DeviceMesh is not compatible with LOCAL_STATE_DICT.",
2088                "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.",
2089            )
2090        else:
2091            state_dict_settings.optim_state_dict_config._use_dtensor = True
2092