xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import contextlib
4import copy
5import functools
6import math
7import traceback
8import warnings
9from contextlib import contextmanager
10from enum import auto, Enum
11from typing import (
12    Any,
13    Callable,
14    Dict,
15    Generator,
16    Iterable,
17    Iterator,
18    List,
19    Optional,
20    Tuple,
21    Union,
22)
23
24import torch
25import torch.distributed as dist
26import torch.distributed.fsdp._traversal_utils as traversal_utils
27import torch.nn as nn
28from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
29    _CHECKPOINT_WRAPPED_MODULE,
30    ActivationWrapper,
31)
32from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
33from torch.distributed.fsdp._common_utils import (
34    _FSDPState,
35    _get_param_to_fqns,
36    FSDP_PREFIX,
37    FSDP_WRAPPED_MODULE,
38    HandleTrainingState,
39    TrainingState,
40)
41from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
42from torch.distributed.fsdp._init_utils import (
43    _check_orig_params_flattened,
44    _init_buffer_state,
45    _init_core_state,
46    _init_device_handle,
47    _init_extension,
48    _init_ignored_module_states,
49    _init_param_handle_from_module,
50    _init_prefetching_state,
51    _init_process_group_state,
52    _init_runtime_state,
53    _init_state_dict_state,
54    HYBRID_SHARDING_STRATEGIES,
55    ProcessGroupType,
56)
57from torch.distributed.fsdp._runtime_utils import (
58    _get_fsdp_root_states,
59    _is_fsdp_root,
60    _lazy_init,
61    _post_forward,
62    _post_forward_reshard,
63    _pre_forward,
64    _pre_forward_unshard,
65    _root_pre_forward,
66    _unshard,
67    _wait_for_computation_stream,
68)
69from torch.distributed.fsdp._wrap_utils import _auto_wrap
70from torch.distributed.fsdp.api import (
71    BackwardPrefetch,
72    CPUOffload,
73    FullOptimStateDictConfig,
74    FullStateDictConfig,
75    LocalOptimStateDictConfig,
76    LocalStateDictConfig,
77    MixedPrecision,
78    OptimStateDictConfig,
79    ShardedOptimStateDictConfig,
80    ShardedStateDictConfig,
81    ShardingStrategy,
82    StateDictConfig,
83    StateDictSettings,
84    StateDictType,
85)
86from torch.distributed.tensor import DeviceMesh
87from torch.distributed.utils import _p_assert
88
89from ._flat_param import FlatParameter, FlatParamHandle
90from ._optim_utils import (
91    _flatten_optim_state_dict,
92    _get_param_id_to_param_from_optim_input,
93    _get_param_key_to_param,
94    _get_param_to_param_id_from_optim_input,
95    _get_param_to_param_key,
96    _optim_state_dict,
97    _rekey_sharded_optim_state_dict,
98    _set_optim_use_dtensor,
99)
100from ._state_dict_utils import _register_all_state_dict_hooks
101from ._unshard_param_utils import (
102    _deregister_orig_params,
103    _register_flat_param,
104    _register_orig_params,
105    _unshard_params,
106    _unshard_params_for_summon,
107)
108from .wrap import CustomPolicy, ModuleWrapPolicy
109
110
111__all__ = [
112    "FullyShardedDataParallel",
113    "OptimStateKeyType",
114]
115
116
117FLAT_PARAM = "_flat_param"
118
119
120class OptimStateKeyType(Enum):
121    """Represents the type of key in an optimizer state-dict."""
122
123    PARAM_NAME = auto()
124    PARAM_ID = auto()
125
126
127class FullyShardedDataParallel(nn.Module, _FSDPState):
128    """A wrapper for sharding module parameters across data parallel workers.
129
130    This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
131    FullyShardedDataParallel is commonly shortened to FSDP.
132
133    .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
134    .. _DeepSpeed: https://www.deepspeed.ai/
135
136    To understand FSDP internals, refer to the
137    :ref:`fsdp_notes`.
138
139    Example::
140
141        >>> # xdoctest: +SKIP("undefined variables")
142        >>> import torch
143        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
144        >>> torch.cuda.set_device(device_id)
145        >>> sharded_module = FSDP(my_module)
146        >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
147        >>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
148        >>> loss = x.sum()
149        >>> loss.backward()
150        >>> optim.step()
151
152    Using FSDP involves wrapping your module and then initializing your
153    optimizer after. This is required since FSDP changes the parameter
154    variables.
155
156    When setting up FSDP, you need to consider the destination CUDA
157    device. If the device has an ID (``dev_id``), you have three options:
158
159    * Place the module on that device
160    * Set the device using ``torch.cuda.set_device(dev_id)``
161    * Pass ``dev_id`` into the ``device_id`` constructor argument.
162
163    This ensures that the FSDP instance's compute device is the
164    destination device. For option 1 and 3, the FSDP initialization
165    always occurs on GPU. For option 2, the FSDP initialization
166    happens on module's current device, which may be a CPU.
167
168    If you're using the ``sync_module_states=True`` flag, you need to
169    ensure that the module is on a GPU or use the ``device_id``
170    argument to specify a CUDA device that FSDP will move the module
171    to in the FSDP constructor. This is necessary because
172    ``sync_module_states=True`` requires GPU communication.
173
174    FSDP also takes care of moving input tensors to the forward method
175    to the GPU compute device, so you don't need to manually move them
176    from CPU.
177
178    For ``use_orig_params=True``,
179    ``ShardingStrategy.SHARD_GRAD_OP`` exposes the unsharded
180    parameters, not the sharded parameters after forward, unlike
181    ``ShardingStrategy.FULL_SHARD``. If you want
182    to inspect the gradients, you can use the ``summon_full_params``
183    method with ``with_grads=True``.
184
185    With ``limit_all_gathers=True``, you may see a gap in the FSDP
186    pre-forward where the CPU thread is not issuing any kernels. This is
187    intentional and shows the rate limiter in effect. Synchronizing the CPU
188    thread in that way prevents over-allocating memory for subsequent
189    all-gathers, and it should not actually delay GPU kernel execution.
190
191    FSDP replaces managed modules' parameters with ``torch.Tensor``
192    views during forward and backward computation for autograd-related
193    reasons. If your module's forward relies on saved references to
194    the parameters instead of reacquiring the references each
195    iteration, then it will not see FSDP's newly created views,
196    and autograd will not work correctly.
197
198    Finally, when using ``sharding_strategy=ShardingStrategy.HYBRID_SHARD``
199    with the sharding process group being intra-node and the
200    replication process group being inter-node, setting
201    ``NCCL_CROSS_NIC=1`` can help improve the all-reduce times over
202    the replication process group for some cluster setups.
203
204    **Limitations**
205
206    There are several limitations to be aware of when using FSDP:
207
208    * FSDP currently does not support gradient accumulation outside
209      ``no_sync()`` when using CPU offloading. This is because FSDP
210      uses the newly-reduced gradient instead of accumulating with any
211      existing gradient, which can lead to incorrect results.
212
213    * FSDP does not support running the forward pass of a submodule
214      that is contained in an FSDP instance. This is because the
215      submodule's parameters will be sharded, but the submodule itself
216      is not an FSDP instance, so its forward pass will not all-gather
217      the full parameters appropriately.
218
219    * FSDP does not work with double backwards due to the way it
220      registers backward hooks.
221
222    * FSDP has some constraints when freezing parameters.
223      For ``use_orig_params=False``, each FSDP instance must manage
224      parameters that are all frozen or all non-frozen. For
225      ``use_orig_params=True``, FSDP supports mixing frozen and
226      non-frozen parameters, but it's recommended to avoid doing so to
227      prevent higher than expected gradient memory usage.
228
229    * As of PyTorch 1.12, FSDP offers limited support for shared
230      parameters. If enhanced shared parameter support is needed for
231      your use case, please post in
232      `this issue <https://github.com/pytorch/pytorch/issues/77724>`__.
233
234    * You should avoid modifying the parameters between forward and
235      backward without using the ``summon_full_params`` context, as
236      the modifications may not persist.
237
238    Args:
239        module (nn.Module):
240            This is the module to be wrapped with FSDP.
241        process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]):
242            This is the process group over which the model is sharded and thus
243            the one used for FSDP's all-gather and reduce-scatter collective
244            communications. If ``None``, then FSDP uses the default process
245            group. For hybrid sharding strategies such as
246            ``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of
247            process groups, representing the groups over which to shard and
248            replicate, respectively. If ``None``, then FSDP constructs process
249            groups for the user to shard intra-node and replicate inter-node.
250            (Default: ``None``)
251        sharding_strategy (Optional[ShardingStrategy]):
252            This configures the sharding strategy, which may trade off memory
253            saving and communication overhead. See :class:`ShardingStrategy`
254            for details. (Default: ``FULL_SHARD``)
255        cpu_offload (Optional[CPUOffload]):
256            This configures CPU offloading. If this is set to ``None``, then
257            no CPU offloading happens. See :class:`CPUOffload` for details.
258            (Default: ``None``)
259        auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]):
260            This specifies a policy to apply FSDP to submodules of ``module``,
261            which is needed for communication and computation overlap and thus
262            affects performance. If ``None``, then FSDP only applies to
263            ``module``, and users should manually apply FSDP to parent modules
264            themselves (proceeding bottom-up). For convenience, this accepts
265            ``ModuleWrapPolicy`` directly, which allows users to specify the
266            module classes to wrap (e.g. the transformer block). Otherwise,
267            this should be a callable that takes in three arguments
268            ``module: nn.Module``, ``recurse: bool``, and
269            ``nonwrapped_numel: int`` and should return a ``bool`` specifying
270            whether the passed-in ``module`` should have FSDP applied if
271            ``recurse=False`` or if the traversal should continue into the
272            module's subtree if ``recurse=True``. Users may add additional
273            arguments to the callable. The ``size_based_auto_wrap_policy`` in
274            ``torch.distributed.fsdp.wrap.py`` gives an example callable that
275            applies FSDP to a module if the parameters in its subtree exceed
276            100M numel. We recommend printing the model after applying FSDP
277            and adjusting as needed.
278
279            Example::
280
281                >>> def custom_auto_wrap_policy(
282                >>>     module: nn.Module,
283                >>>     recurse: bool,
284                >>>     nonwrapped_numel: int,
285                >>>     # Additional custom arguments
286                >>>     min_num_params: int = int(1e8),
287                >>> ) -> bool:
288                >>>     return nonwrapped_numel >= min_num_params
289                >>> # Configure a custom `min_num_params`
290                >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
291
292        backward_prefetch (Optional[BackwardPrefetch]):
293            This configures explicit backward prefetching of all-gathers. If
294            ``None``, then FSDP does not backward prefetch, and there is no
295            communication and computation overlap in the backward pass. See
296            :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``)
297        mixed_precision (Optional[MixedPrecision]):
298            This configures native mixed precision for FSDP. If this is set to
299            ``None``, then no mixed precision is used. Otherwise, parameter,
300            buffer, and gradient reduction dtypes can be set. See
301            :class:`MixedPrecision` for details. (Default: ``None``)
302        ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
303            own parameters and child modules' parameters and buffers are
304            ignored by this instance. None of the modules directly in
305            ``ignored_modules`` should be :class:`FullyShardedDataParallel`
306            instances, and any child modules that are already-constructed
307            :class:`FullyShardedDataParallel` instances will not be ignored if
308            they are nested under this instance. This argument may be used to
309            avoid sharding specific parameters at module granularity when using an
310            ``auto_wrap_policy`` or if parameters' sharding is not managed by
311            FSDP. (Default: ``None``)
312        param_init_fn (Optional[Callable[[nn.Module], None]]):
313            A ``Callable[torch.nn.Module] -> None`` that
314            specifies how modules that are currently on the meta device should
315            be initialized onto an actual device. As of v1.12, FSDP detects
316            modules with parameters or buffers on meta device via ``is_meta``
317            and either applies ``param_init_fn`` if specified or calls
318            ``nn.Module.reset_parameters()`` otherwise. For both cases, the
319            implementation should *only* initialize the parameters/buffers of
320            the module, not those of its submodules. This is to avoid
321            re-initialization. In addition, FSDP also supports deferred
322            initialization via torchdistX's (https://github.com/pytorch/torchdistX)
323            ``deferred_init()`` API, where the deferred modules are initialized
324            by calling ``param_init_fn`` if specified or torchdistX's default
325            ``materialize_module()`` otherwise. If ``param_init_fn`` is
326            specified, then it is applied to all meta-device modules, meaning
327            that it should probably case on the module type. FSDP calls the
328            initialization function before parameter flattening and sharding.
329
330            Example::
331
332                >>> # xdoctest: +SKIP("undefined variables")
333                >>> module = MyModule(device="meta")
334                >>> def my_init_fn(module: nn.Module):
335                >>>     # E.g. initialize depending on the module type
336                >>>     ...
337                >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
338                >>> print(next(fsdp_model.parameters()).device) # current CUDA device
339                >>> # With torchdistX
340                >>> module = deferred_init.deferred_init(MyModule, device="cuda")
341                >>> # Will initialize via deferred_init.materialize_module().
342                >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
343
344        device_id (Optional[Union[int, torch.device]]): An ``int`` or
345            ``torch.device`` giving the CUDA device on which FSDP
346            initialization takes place, including the module initialization
347            if needed and the parameter sharding. This should be specified to
348            improve initialization speed if ``module`` is on CPU. If the
349            default CUDA device was set (e.g. via ``torch.cuda.set_device``),
350            then the user may pass ``torch.cuda.current_device`` to this.
351            (Default: ``None``)
352        sync_module_states (bool): If ``True``, then each FSDP module will
353            broadcast module parameters and buffers from rank 0 to ensure that
354            they are replicated across ranks (adding communication overhead to
355            this constructor). This can help load ``state_dict`` checkpoints
356            via ``load_state_dict`` in a memory efficient way. See
357            :class:`FullStateDictConfig` for an example of this. (Default:
358            ``False``)
359        forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches
360            the next forward-pass all-gather before the current forward
361            computation. This is only useful for CPU-bound workloads, in which
362            case issuing the next all-gather earlier may improve overlap. This
363            should only be used for static-graph models since the prefetching
364            follows the first iteration's execution order. (Default: ``False``)
365        limit_all_gathers (bool): If ``True``, then FSDP explicitly
366            synchronizes the CPU thread to ensure GPU memory usage from only
367            *two* consecutive FSDP instances (the current instance running
368            computation and the next instance whose all-gather is prefetched).
369            If ``False``, then FSDP allows the CPU thread to issue all-gathers
370            without any extra synchronization. (Default: ``True``) We often
371            refer to this feature as the "rate limiter". This flag should only
372            be set to ``False`` for specific CPU-bound workloads with low
373            memory pressure in which case the CPU thread can aggressively issue
374            all kernels without concern for the GPU memory usage.
375        use_orig_params (bool): Setting this to ``True`` has FSDP use
376            ``module`` 's original parameters. FSDP exposes those original
377            parameters to the user via :meth:`nn.Module.named_parameters`
378            instead of FSDP's internal :class:`FlatParameter` s. This means
379            that the optimizer step runs on the original parameters, enabling
380            per-original-parameter hyperparameters. FSDP preserves the original
381            parameter variables and manipulates their data between unsharded
382            and sharded forms, where they are always views into the underlying
383            unsharded or sharded :class:`FlatParameter`, respectively. With the
384            current algorithm, the sharded form is always 1D, losing the
385            original tensor structure. An original parameter may have all,
386            some, or none of its data present for a given rank. In the none
387            case, its data will be like a size-0 empty tensor. Users should not
388            author programs relying on what data is present for a given
389            original parameter in its sharded form. ``True`` is required to
390            use ``torch.compile()``. Setting this to ``False`` exposes FSDP's
391            internal :class:`FlatParameter` s to the user via
392            :meth:`nn.Module.named_parameters`. (Default: ``False``)
393        ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]):
394            Ignored parameters or modules that will not be managed by this FSDP
395            instance, meaning that the parameters are not sharded and their
396            gradients are not reduced across ranks. This argument unifies with
397            the existing ``ignored_modules`` argument, and we may deprecate
398            ``ignored_modules`` soon. For backward compatibility, we keep both
399            ``ignored_states`` and `ignored_modules``, but FSDP only allows one
400            of them to be specified as not ``None``.
401        device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an altenative to
402            process_group. When device_mesh is passed, FSDP will use the underlying process
403            groups for all-gather and reduce-scatter collective communications. Therefore,
404            these two args need to be mutually exclusive. For hybrid sharding strategies such as
405            ``ShardingStrategy.HYBRID_SHARD``, users can pass in a 2D DeviceMesh instead
406            of a tuple of process groups. For 2D FSDP + TP, users are required to pass in
407            device_mesh instead of process_group. For more DeviceMesh info, please visit:
408            https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
409    """
410
411    def __init__(
412        self,
413        module: nn.Module,
414        process_group: ProcessGroupType = None,
415        sharding_strategy: Optional[ShardingStrategy] = None,
416        cpu_offload: Optional[CPUOffload] = None,
417        auto_wrap_policy: Optional[
418            Union[Callable, ModuleWrapPolicy, CustomPolicy]
419        ] = None,
420        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
421        mixed_precision: Optional[MixedPrecision] = None,
422        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
423        param_init_fn: Optional[Callable[[nn.Module], None]] = None,
424        device_id: Optional[Union[int, torch.device]] = None,
425        sync_module_states: bool = False,
426        forward_prefetch: bool = False,
427        limit_all_gathers: bool = True,
428        use_orig_params: bool = False,
429        ignored_states: Union[
430            Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
431        ] = None,
432        device_mesh: Optional[DeviceMesh] = None,
433    ):
434        torch._C._log_api_usage_once("torch.distributed.fsdp")
435        super().__init__()
436        if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
437            warnings.warn(
438                "FSDP will not all-gather parameters for containers that do "
439                f"not implement forward: {module}",
440                stacklevel=2,
441            )
442        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
443        _init_device_handle(self, module, self._ignored_params, device_id)
444
445        # Add module annotations for Dynamo support (see function for details)
446        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)
447
448        # Initializes self.process_group, along with rank and world size. This will
449        # also set another attribute, _inter_node_pg, to control the process group
450        # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
451        # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
452        # the same process group state as the root FSDP module.
453        self._device_mesh = device_mesh
454        _init_process_group_state(
455            self,
456            process_group,
457            sharding_strategy,
458            auto_wrap_policy,
459            device_mesh,
460        )
461        if auto_wrap_policy is not None:
462            root_kwargs = {
463                "process_group": process_group,
464                "sharding_strategy": sharding_strategy,
465                "cpu_offload": cpu_offload,
466                "backward_prefetch": backward_prefetch,
467                "mixed_precision": mixed_precision,
468                "param_init_fn": param_init_fn,
469                "device_id": device_id,
470                "sync_module_states": sync_module_states,
471                "forward_prefetch": forward_prefetch,
472                "limit_all_gathers": limit_all_gathers,
473                "use_orig_params": use_orig_params,
474                "ignored_states": self._ignored_params,
475                "device_mesh": device_mesh,
476            }
477            if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None:
478                # Share root process groups with children to maintain
479                # the invariant that all FSDP modules will have the same
480                # process groups.
481                root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)
482
483            _auto_wrap(
484                module,
485                auto_wrap_policy,
486                self._ignored_modules,
487                self._ignored_params,
488                root_kwargs,
489                FullyShardedDataParallel,
490            )
491
492        backward_prefetch_limit = 1
493        forward_prefetch_limit = 1
494        _init_core_state(
495            self,
496            sharding_strategy,
497            mixed_precision,
498            cpu_offload,
499            limit_all_gathers,
500            use_orig_params,
501            backward_prefetch_limit,
502            forward_prefetch_limit,
503        )
504        _init_runtime_state(self)
505        _init_prefetching_state(self, backward_prefetch, forward_prefetch)
506        _init_buffer_state(self, module)
507        # extension needs to be set before `_init_param_handle_from_module()`
508        _init_extension(self, device_mesh)
509        _init_param_handle_from_module(
510            self,
511            module,
512            device_id,
513            param_init_fn,
514            sync_module_states,
515        )
516        self._fsdp_wrapped_module = module
517        if not use_orig_params:
518            _check_orig_params_flattened(self, self._ignored_params)
519            _register_flat_param(self, self)
520
521        # `_state_dict_type` controls the `state_dict()` behavior, which is
522        # implemented using post-save and pre-load hooks
523        _init_state_dict_state(self)
524        _register_all_state_dict_hooks(self)
525        self._zero_scalar = None
526
527    @property
528    def module(self) -> nn.Module:
529        """Return the wrapped module."""
530        # FSDP's `.module` must refer to the innermost wrapped module when
531        # composing with other module wrappers in order for state dict to work
532        if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
533            return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
534        return self._fsdp_wrapped_module
535
536    @property
537    def _has_params(self) -> bool:
538        """Returns whether this FSDP instance manages any parameters."""
539        return hasattr(self, "_handle") and self._handle is not None
540
541    @property
542    def _flat_param(self) -> Optional[FlatParameter]:
543        return self._handle.flat_param if self._handle else None
544
545    def __getattr__(self, name: str) -> Any:
546        """Forward missing attributes to the wrapped module."""
547        try:
548            return super().__getattr__(name)  # defer to nn.Module's logic
549        except AttributeError:
550            return getattr(self._fsdp_wrapped_module, name)
551
552    def __getitem__(self, key: int) -> Any:
553        """Forward indexing calls in case the module is an ``nn.Sequential``."""
554        if hasattr(self, FSDP_WRAPPED_MODULE):
555            return self._fsdp_wrapped_module.__getitem__(key)  # type: ignore[operator]
556        return super().__getitem__(key)
557
558    def check_is_root(self) -> bool:
559        """Check if this instance is a root FSDP module."""
560        return _is_fsdp_root(self, self)
561
562    @staticmethod
563    def fsdp_modules(
564        module: nn.Module,
565        root_only: bool = False,
566    ) -> List["FullyShardedDataParallel"]:
567        """Return all nested FSDP instances.
568
569        This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.
570
571        Args:
572            module (torch.nn.Module): Root module, which may or may not be an
573                ``FSDP`` module.
574            root_only (bool): Whether to return only FSDP root modules.
575                (Default: ``False``)
576
577        Returns:
578            List[FullyShardedDataParallel]: FSDP modules that are nested in
579            the input ``module``.
580        """
581        if root_only:
582            return _get_fsdp_root_states(module)
583        return traversal_utils._get_fsdp_states(module)
584
585    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
586        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
587
588        Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).
589
590        Compared to ``torch.nn.Module.apply``, this version additionally gathers
591        the full parameters before applying ``fn``. It should not be called from
592        within another ``summon_full_params`` context.
593
594        Args:
595            fn (:class:`Module` -> None): function to be applied to each submodule
596
597        Returns:
598            Module: self
599        """
600        uninitialized = self._is_root is None
601        self._assert_state(TrainingState.IDLE)
602        # Use `_unshard_params_for_summon()` with `recurse=False` instead of
603        # `_unshard_fsdp_state_params()` directly to perform lazy
604        # initialization, which is needed to initialize `FlatParameter`
605        # parameter attributes as required by the unshard logic
606        with _unshard_params_for_summon(
607            self,
608            self,
609            writeback=True,
610            rank0_only=False,
611            offload_to_cpu=False,
612            with_grads=False,
613        ):
614            ret = super().apply(fn)
615
616        # Reset lazy init called in `_unshard_params_for_summon()` since
617        # `apply()` may have been called on FSDP instance that is not truly a
618        # root, in which case it will be incorrectly marked as one.
619        if uninitialized and self._is_root:
620            for module in traversal_utils._get_fsdp_states(self):
621                module._reset_lazy_init()
622
623        return ret
624
625    def _mixed_precision_enabled_for_buffers(self) -> bool:
626        """Return whether the user explicitly enabled buffer mixed precision.
627
628        NOTE: Unlike parameters and gradient reduction, buffer mixed precision
629        is applied at the FSDP instance level, not the ``FlatParameter`` level,
630        which may be different for the composable code path.
631        """
632        return self.mixed_precision.buffer_dtype is not None
633
634    def _low_precision_hook_enabled(self) -> bool:
635        """Whether a low precision hook is registered or not."""
636        return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS
637
638    def _reset_lazy_init(self) -> None:
639        """Reset instance so :func:`_lazy_init` will run on the next forward."""
640        self._is_root: Optional[bool] = None
641
642    @staticmethod
643    def set_state_dict_type(
644        module: nn.Module,
645        state_dict_type: StateDictType,
646        state_dict_config: Optional[StateDictConfig] = None,
647        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
648    ) -> StateDictSettings:
649        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
650
651        Also takes (optional) configuration for the model's and optimizer's state dict.
652        The target module does not have to be a FSDP module. If the target
653        module is a FSDP module, its ``state_dict_type`` will also be changed.
654
655        .. note:: This API should be called for only the top-level (root)
656            module.
657
658        .. note:: This API enables users to transparently use the conventional
659            ``state_dict`` API to take model checkpoints in cases where the
660            root FSDP module is wrapped by another ``nn.Module``. For example,
661            the following will ensure ``state_dict`` is called on all non-FSDP
662            instances, while dispatching into `sharded_state_dict` implementation
663            for FSDP:
664
665        Example::
666
667            >>> # xdoctest: +SKIP("undefined variables")
668            >>> model = DDP(FSDP(...))
669            >>> FSDP.set_state_dict_type(
670            >>>     model,
671            >>>     StateDictType.SHARDED_STATE_DICT,
672            >>>     state_dict_config = ShardedStateDictConfig(offload_to_cpu=True),
673            >>>     optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True),
674            >>> )
675            >>> param_state_dict = model.state_dict()
676            >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
677
678        Args:
679            module (torch.nn.Module): Root module.
680            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
681            state_dict_config (Optional[StateDictConfig]): the configuration for the
682                target ``state_dict_type``.
683            optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration
684                for the optimizer state dict.
685
686        Returns:
687            A StateDictSettings that include the previous state_dict type and
688            configuration for the module.
689        """
690        warnings.warn(
691            "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being "
692            "deprecated. Please use APIs, get_state_dict() and set_state_dict(), "
693            "which can support different parallelisms, FSDP1, FSDP2, DDP. "
694            "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html"
695            "#torch.distributed.checkpoint.state_dict.get_state_dict ."
696            "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .",
697            FutureWarning,
698        )
699        _state_dict_type_to_config = {
700            StateDictType.FULL_STATE_DICT: FullStateDictConfig,
701            StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
702            StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
703        }
704        _optim_state_dict_type_to_config = {
705            StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig,
706            StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig,
707            StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig,
708        }
709
710        # Use the default config if a state_dict config is not set.
711        state_dict_config_type = _state_dict_type_to_config[state_dict_type]
712        optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type]
713        if state_dict_config is None:
714            state_dict_config = state_dict_config_type()
715        if optim_state_dict_config is None:
716            optim_state_dict_config = optim_state_dict_config_type()
717        if state_dict_config_type != type(state_dict_config):
718            raise RuntimeError(
719                f"Expected state_dict_config of type {state_dict_config_type} "
720                f"but got {type(state_dict_config)}"
721            )
722        if optim_state_dict_config_type != type(optim_state_dict_config):
723            raise RuntimeError(
724                f"Expected optim_state_dict_config of type {optim_state_dict_config_type} "
725                f"but got {type(optim_state_dict_config)}"
726            )
727
728        # Set the state_dict type and configurations.
729        prev_state_dict_type = None
730        prev_state_dict_config = None
731        prev_optim_state_dict_config = None
732        for submodule in traversal_utils._get_fsdp_states(module):
733            if prev_state_dict_type is None:
734                prev_state_dict_type = submodule._state_dict_type
735            else:
736                assert (
737                    prev_state_dict_type == submodule._state_dict_type
738                ), "All FSDP modules should have the same state_dict_type."
739            if prev_state_dict_config is None:
740                prev_state_dict_config = submodule._state_dict_config
741            else:
742                assert isinstance(
743                    submodule._state_dict_config, type(prev_state_dict_config)
744                ), "All FSDP modules must have the same type of state_dict_config."
745            if prev_optim_state_dict_config is None:
746                prev_optim_state_dict_config = submodule._optim_state_dict_config
747            else:
748                assert isinstance(
749                    submodule._optim_state_dict_config,
750                    type(prev_optim_state_dict_config),
751                ), "All FSDP modules must have the same type of optim_state_dict_config."
752
753            submodule._state_dict_type = state_dict_type
754            submodule._state_dict_config = state_dict_config
755            submodule._optim_state_dict_config = optim_state_dict_config
756
757        return StateDictSettings(
758            prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config
759        )
760
761    @staticmethod
762    def get_state_dict_type(module: nn.Module) -> StateDictSettings:
763        """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``.
764
765        The target module does not have to be an FSDP module.
766
767        Returns:
768            A ``StateDictSettings`` containing the state_dict_type and
769            state_dict / optim_state_dict configs that are currently set.
770
771        Raises:
772            ``AssertionError`` if the ``StateDictSettings`` for different
773            FSDP submodules differ.
774        """
775        state_dict_settings: Optional[StateDictSettings] = None
776        for submodule in FullyShardedDataParallel.fsdp_modules(module):
777            if state_dict_settings is None:
778                state_dict_settings = StateDictSettings(
779                    state_dict_type=submodule._state_dict_type,
780                    state_dict_config=submodule._state_dict_config,
781                    optim_state_dict_config=submodule._optim_state_dict_config,
782                )
783                _set_optim_use_dtensor(submodule, state_dict_settings)
784            else:
785                submodule_settings = StateDictSettings(
786                    submodule._state_dict_type,
787                    submodule._state_dict_config,
788                    submodule._optim_state_dict_config,
789                )
790                assert state_dict_settings == submodule_settings, (
791                    "All FSDP modules must have the same state dict settings."
792                    f"Got {submodule_settings} and {state_dict_settings}."
793                )
794                _set_optim_use_dtensor(submodule, submodule_settings)
795        return state_dict_settings
796
797    @staticmethod
798    @contextlib.contextmanager
799    def state_dict_type(
800        module: nn.Module,
801        state_dict_type: StateDictType,
802        state_dict_config: Optional[StateDictConfig] = None,
803        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
804    ) -> Generator:
805        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.
806
807        This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of
808        :meth:`set_state_dict_type` for the detail.
809
810        Example::
811
812            >>> # xdoctest: +SKIP("undefined variables")
813            >>> model = DDP(FSDP(...))
814            >>> with FSDP.state_dict_type(
815            >>>     model,
816            >>>     StateDictType.SHARDED_STATE_DICT,
817            >>> ):
818            >>>     checkpoint = model.state_dict()
819
820        Args:
821            module (torch.nn.Module): Root module.
822            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
823            state_dict_config (Optional[StateDictConfig]): the model ``state_dict``
824                configuration for the target ``state_dict_type``.
825            optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer
826               ``state_dict`` configuration for the target ``state_dict_type``.
827        """
828        prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
829            module,
830            state_dict_type,
831            state_dict_config,
832            optim_state_dict_config,
833        )
834        yield
835        FullyShardedDataParallel.set_state_dict_type(
836            module,
837            prev_state_dict_settings.state_dict_type,
838            prev_state_dict_settings.state_dict_config,
839            prev_state_dict_settings.optim_state_dict_config,
840        )
841
842    def forward(self, *args: Any, **kwargs: Any) -> Any:
843        """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
844        handle = self._handle
845        with torch.autograd.profiler.record_function(
846            "FullyShardedDataParallel.forward"
847        ):
848            args, kwargs = _root_pre_forward(self, self, args, kwargs)
849            unused = None
850            args, kwargs = _pre_forward(
851                self,
852                handle,
853                _pre_forward_unshard,
854                self._fsdp_wrapped_module,
855                args,
856                kwargs,
857            )
858            if handle:
859                _p_assert(
860                    handle.flat_param.device == self.compute_device,
861                    "Expected `FlatParameter` to be on the compute device "
862                    f"{self.compute_device} but got {handle.flat_param.device}",
863                )
864            output = self._fsdp_wrapped_module(*args, **kwargs)
865            return _post_forward(
866                self, handle, _post_forward_reshard, self, unused, output
867            )
868
869    @staticmethod
870    @contextlib.contextmanager
871    def summon_full_params(
872        module: nn.Module,
873        recurse: bool = True,
874        writeback: bool = True,
875        rank0_only: bool = False,
876        offload_to_cpu: bool = False,
877        with_grads: bool = False,
878    ) -> Generator:
879        r"""Expose full params for FSDP instances with this context manager.
880
881        Can be useful *after* forward/backward for a model to get
882        the params for additional processing or checking. It can take a non-FSDP
883        module and will summon full params for all contained FSDP modules as
884        well as their children, depending on the ``recurse`` argument.
885
886        .. note:: This can be used on inner FSDPs.
887        .. note:: This can *not* be used within a forward or backward pass. Nor
888            can forward and backward be started from within this context.
889        .. note:: Parameters will revert to their local shards after the context
890            manager exits, storage behavior is the same as forward.
891        .. note:: The full parameters can be modified, but only the portion
892            corresponding to the local param shard will persist after the
893            context manager exits (unless ``writeback=False``, in which case
894            changes will be discarded). In the case where FSDP does not shard
895            the parameters, currently only when ``world_size == 1``, or ``NO_SHARD``
896            config, the modification is persisted regardless of ``writeback``.
897        .. note:: This method works on modules which are not FSDP themselves but
898            may contain multiple independent FSDP units. In that case, the given
899            arguments will apply to all contained FSDP units.
900
901        .. warning:: Note that ``rank0_only=True`` in conjunction with
902            ``writeback=True`` is not currently supported and will raise an
903            error. This is because model parameter shapes would be different
904            across ranks within the context, and writing to them can lead to
905            inconsistency across ranks when the context is exited.
906
907        .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will
908            result in full parameters being redundantly copied to CPU memory for
909            GPUs that reside on the same machine, which may incur the risk of
910            CPU OOM. It is recommended to use ``offload_to_cpu`` with
911            ``rank0_only=True``.
912
913        Args:
914            recurse (bool, Optional): recursively summon all params for nested
915                FSDP instances (default: True).
916            writeback (bool, Optional): if ``False``, modifications to params are
917                discarded after the context manager exits;
918                disabling this can be slightly more efficient (default: True)
919            rank0_only (bool, Optional): if ``True``, full parameters are
920                materialized on only global rank 0. This means that within the
921                context, only rank 0 will have full parameters and the other
922                ranks will have sharded parameters. Note that setting
923                ``rank0_only=True`` with ``writeback=True`` is not supported,
924                as model parameter shapes will be different across ranks
925                within the context, and writing to them can lead to
926                inconsistency across ranks when the context is exited.
927            offload_to_cpu (bool, Optional): If ``True``, full parameters are
928                offloaded to CPU. Note that this offloading currently only
929                occurs if the parameter is sharded (which is only not the case
930                for world_size = 1 or ``NO_SHARD`` config). It is recommended
931                to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid
932                redundant copies of model parameters being offloaded to the same CPU memory.
933            with_grads (bool, Optional): If ``True``, gradients are also
934                unsharded with the parameters. Currently, this is only
935                supported when passing ``use_orig_params=True`` to the FSDP
936                constructor and ``offload_to_cpu=False`` to this method.
937                (Default: ``False``)
938        """
939        with _unshard_params(
940            module, recurse, writeback, rank0_only, offload_to_cpu, with_grads
941        ):
942            yield
943
944    @contextlib.contextmanager
945    def _deregister_orig_params_ctx(self):
946        """Deregister the original parameters and expose the :class:`FlatParameter`.
947
948        If a :class:`FlatParameter` is sharded, then
949        this refreshes the sharded views before exiting. This method should
950        only be called when using the original parameters.
951        """
952        _p_assert(
953            self._use_orig_params,
954            "`_deregister_orig_params_ctx()` should only be called when "
955            "`_use_orig_params=True`",
956        )
957        for fsdp_module in traversal_utils._get_fsdp_states(self):
958            _deregister_orig_params(fsdp_module, fsdp_module)
959        try:
960            yield
961        finally:
962            for fsdp_module in traversal_utils._get_fsdp_states(self):
963                _register_orig_params(fsdp_module, fsdp_module)
964
965    def _apply(self, *args, **kwargs):
966        """Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``."""
967        # When using the original parameters: Since (1) the `FlatParameter`s
968        # own the storage and (2) `_apply()` is the subroutine underlying the
969        # most common storage-changing ops like `to()` and `cuda()`, we
970        # override `_apply()` to have the storage change directly performed on
971        # the `FlatParameter`s instead of applying to the original parameters
972        # and then writing back to the `FlatParameter`s.
973        context = (
974            self._deregister_orig_params_ctx()
975            if self._use_orig_params
976            else contextlib.nullcontext()
977        )
978        with context:
979            return super()._apply(*args, **kwargs)
980
981    def named_buffers(
982        self,
983        *args,
984        **kwargs,
985    ) -> Iterator[Tuple[str, torch.Tensor]]:
986        """Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
987
988        Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix
989        when inside the :meth:`summon_full_params` context manager.
990        """
991        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
992        for buffer_name, buffer in super().named_buffers(*args, **kwargs):
993            if should_clean_name:
994                # Remove any instances of the FSDP-specific prefix; there can
995                # be multiple in the case of nested FSDP modules
996                buffer_name = buffer_name.replace(FSDP_PREFIX, "")
997            yield (buffer_name, buffer)
998
999    def named_parameters(
1000        self,
1001        *args,
1002        **kwargs,
1003    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
1004        """Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
1005
1006        Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix
1007        when inside the :meth:`summon_full_params` context manager.
1008        """
1009        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
1010        for param_name, param in super().named_parameters(*args, **kwargs):
1011            if should_clean_name:
1012                # Remove any instances of the FSDP-specific prefix; there can
1013                # be multiple in the case of nested FSDP modules
1014                param_name = param_name.replace(FSDP_PREFIX, "")
1015            yield (param_name, param)
1016
1017    def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
1018        """Assert we are in the given state."""
1019        # Since assert can be turned off and this error checking
1020        # is really important, we use explicit error checking
1021        # and raise a ValueError if needed.
1022        if isinstance(state, TrainingState):
1023            state = [state]
1024        if self.training_state not in state:
1025            msg = (
1026                f"expected to be in states {state} but current state "
1027                f"is {self.training_state}"
1028            )
1029            # In case we are failing in the context of autograd hook, asserting
1030            # may not generate useful msg. So, let's print it to be sure.
1031            if self.rank == 0:
1032                print(f"Asserting FSDP instance is: {self}")
1033                print(f"ERROR: {msg}")
1034                traceback.print_stack()
1035            raise ValueError(msg)
1036
1037    @contextmanager
1038    def no_sync(self) -> Generator:
1039        """Disable gradient synchronizations across FSDP instances.
1040
1041        Within this context, gradients will be accumulated in module
1042        variables, which will later be synchronized in the first
1043        forward-backward pass after exiting the context. This should only be
1044        used on the root FSDP instance and will recursively apply to all
1045        children FSDP instances.
1046
1047        .. note:: This likely results in higher memory usage because FSDP will
1048            accumulate the full model gradients (instead of gradient shards)
1049            until the eventual sync.
1050
1051        .. note:: When used with CPU offloading, the gradients will not be
1052            offloaded to CPU when inside the context manager. Instead, they
1053            will only be offloaded right after the eventual sync.
1054        """
1055        _lazy_init(self, self)
1056        if not self._is_root:
1057            raise RuntimeError(
1058                "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module."
1059            )
1060        self._assert_state(TrainingState.IDLE)
1061        old_flags = []
1062        for m in self.modules():
1063            if isinstance(m, FullyShardedDataParallel):
1064                old_flags.append((m, m._sync_gradients))
1065                m._sync_gradients = False
1066        try:
1067            yield
1068        finally:
1069            for m, old_flag in old_flags:
1070                assert not m._sync_gradients, (
1071                    "`_sync_gradients` was incorrectly set to "
1072                    "`True` while in the `no_sync()` context manager"
1073                )
1074                m._sync_gradients = old_flag
1075
1076    @torch.no_grad()
1077    def clip_grad_norm_(
1078        self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
1079    ) -> torch.Tensor:
1080        """Clip the gradient norm of all parameters.
1081
1082        The norm is computed over all parameters' gradients as viewed as a single vector, and the
1083        gradients are modified in-place.
1084
1085        Args:
1086            max_norm (float or int): max norm of the gradients
1087            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
1088                for infinity norm.
1089
1090        Returns:
1091            Total norm of the parameters (viewed as a single vector).
1092
1093        If every FSDP instance uses ``NO_SHARD``, meaning that no
1094        gradients are sharded across ranks, then you may directly use
1095        :func:`torch.nn.utils.clip_grad_norm_`.
1096
1097        If at least some FSDP instance uses a sharded strategy (i.e.
1098        one other than ``NO_SHARD``), then you should use this method
1099        instead of :func:`torch.nn.utils.clip_grad_norm_` since this method
1100        handles the fact that gradients are sharded across ranks.
1101
1102        The total norm returned will have the "largest" dtype across
1103        all parameters/gradients as defined by PyTorch's type promotion
1104        semantics. For example, if *all* parameters/gradients use a low
1105        precision dtype, then the returned norm's dtype will be that low
1106        precision dtype, but if there exists at least one parameter/
1107        gradient using FP32, then the returned norm's dtype will be FP32.
1108
1109        .. warning:: This needs to be called on all ranks since it uses
1110            collective communications.
1111        """
1112        _lazy_init(self, self)
1113        if not self._is_root:
1114            raise RuntimeError(
1115                "`clip_grad_norm_()` should only be called on the root FSDP instance"
1116            )
1117        if self._zero_scalar is None:
1118            self._zero_scalar = torch.tensor(0.0, device=self.compute_device)
1119        self._assert_state(TrainingState.IDLE)
1120        # If every FSDP instance uses `NO_SHARD`, then we can directly use
1121        # the normal `nn.utils` one targeting local gradients
1122        all_no_shard = all(
1123            not handle.uses_sharded_strategy for handle in self._all_handles
1124        )
1125        if all_no_shard:
1126            return torch.nn.utils.clip_grad_norm_(
1127                self.parameters(), max_norm, norm_type
1128            )
1129        # Otherwise, there exists some FSDP instance using a sharded strategy,
1130        # where sharded and non-sharded parameters must be handled separately
1131        max_norm = float(max_norm)
1132        norm_type = float(norm_type)
1133        sharded_params_set = set()
1134        nonsharded_params_set = set()  # `NO_SHARD` or not FSDP-managed
1135        # Make sure to compute the local norm using lists for deterministic
1136        # iteration order and hence deterministic total norm computation
1137        sharded_params = []
1138        nonsharded_params = []
1139        grads: List[torch.Tensor] = []
1140        for handle in self._all_handles:
1141            if handle.uses_sharded_strategy:
1142                target_set = sharded_params_set
1143                target_list = sharded_params
1144            else:
1145                target_set = nonsharded_params_set
1146                target_list = nonsharded_params
1147            if handle._use_orig_params:
1148                for param in handle.flat_param._params:
1149                    if param not in target_set:
1150                        target_set.add(param)
1151                        target_list.append(param)
1152                        if param.grad is not None:
1153                            grads.append(param.grad)
1154            else:
1155                if handle.flat_param not in target_set:
1156                    target_set.add(handle.flat_param)
1157                    target_list.append(handle.flat_param)
1158                    if handle.flat_param.grad is not None:
1159                        grads.append(handle.flat_param.grad)
1160        for param in self.parameters():
1161            not_fsdp_managed = (
1162                param not in sharded_params_set and param not in nonsharded_params_set
1163            )
1164            if not_fsdp_managed:
1165                nonsharded_params_set.add(param)
1166                nonsharded_params.append(param)
1167                if param.grad is not None:
1168                    grads.append(param.grad)
1169        # Compute local norms (forced to be in FP32)
1170        local_sharded_norm = _get_grad_norm(
1171            sharded_params, norm_type, self._zero_scalar, self.compute_device
1172        )
1173        local_nonsharded_norm = (
1174            _get_grad_norm(
1175                nonsharded_params, norm_type, self._zero_scalar, self.compute_device
1176            )
1177            if nonsharded_params
1178            else None
1179        )
1180        # Reconstruct the total gradient norm depending on the norm type
1181        if norm_type == math.inf:
1182            total_norm = (
1183                torch.maximum(local_sharded_norm, local_nonsharded_norm)
1184                if local_nonsharded_norm is not None
1185                else local_sharded_norm
1186            )
1187            dist.all_reduce(
1188                total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
1189            )
1190        else:
1191            total_norm = local_sharded_norm**norm_type
1192            dist.all_reduce(total_norm, group=self.process_group)
1193            # All-reducing the local non-sharded norm would count it
1194            # world-size-many times (once per rank) instead of once
1195            if local_nonsharded_norm is not None:
1196                total_norm += local_nonsharded_norm**norm_type
1197            total_norm = total_norm ** (1.0 / norm_type)
1198        if self.cpu_offload.offload_params:
1199            total_norm = total_norm.cpu()
1200
1201        clip_coef = max_norm / (total_norm + 1e-6)
1202        # Multiplying by the clamped coefficient is meaningless when it is
1203        # equal to 1, but it avoids the host-device sync that would result from
1204        # `if clip_coef < 1`
1205        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
1206        for grad in grads:
1207            grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype))
1208        # Use the "largest" dtype by type promotion semantics to use the same
1209        # dtype as if we did not force local norm computation to be in FP32
1210        if len(grads) == 0:
1211            # If this rank has no gradients, then we must default to FP32
1212            # unless we use additional communication, which we prefer to avoid
1213            # since `clip_grad_norm_()` is called in the training loop
1214            warnings.warn(
1215                f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no "
1216                "gradients -- returning the total norm in the default dtype "
1217                f"{total_norm.dtype}"
1218            )  # warn since this is generally unexpected
1219            return total_norm
1220        total_norm_dtype = functools.reduce(
1221            torch.promote_types,
1222            [grad.dtype for grad in grads],
1223        )
1224        return total_norm.to(total_norm_dtype)
1225
1226    @staticmethod
1227    def _warn_optim_input(optim_input, *, stacklevel: int = 1):
1228        if optim_input is not None:
1229            warnings.warn(
1230                "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. "
1231                "You may remove it from your code without changing its functionality.",
1232                FutureWarning,
1233                stacklevel=stacklevel + 1,
1234            )
1235
1236    @staticmethod
1237    def _is_using_optim_input(optim_input, optim) -> bool:
1238        if optim_input is None and optim is None:
1239            # Use the default behavior of ``optim_input``
1240            return True
1241        if optim_input is not None:
1242            # Use the `optim_input` code path
1243            return True
1244        # Use the `optim` code path
1245        return False
1246
1247    @staticmethod
1248    def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1):
1249        warnings.warn(
1250            f"``FullyShardedDataParallel.{curr}``is being deprecated and is "
1251            f"replaced by ``FullyShardedDataParallel.{new}``. "
1252            f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.",
1253            FutureWarning,
1254            stacklevel=stacklevel + 1,
1255        )
1256
1257    @staticmethod
1258    def _optim_state_dict_impl(
1259        model: torch.nn.Module,
1260        optim: torch.optim.Optimizer,
1261        optim_state_dict: Dict[str, Any],
1262        optim_input: Optional[
1263            Union[
1264                List[Dict[str, Any]],
1265                Iterable[torch.nn.Parameter],
1266            ]
1267        ] = None,
1268        rank0_only: bool = True,
1269        full_state_dict: bool = True,
1270        group: Optional[dist.ProcessGroup] = None,
1271        cpu_offload: bool = True,
1272        *,
1273        _stacklevel: int = 1,
1274    ) -> Dict[str, Any]:
1275        """Transform the state-dict of an optimizer corresponding to a sharded model.
1276
1277        This is the internal API that is used by all the optim_state_dict implementations.
1278        Given model, optim, and the original optim_state_dict, this API removes the
1279        FSDP internal information and internal sharding from the optim_state_dict.
1280        """
1281        if full_state_dict:
1282            FullyShardedDataParallel._warn_optim_input(
1283                optim_input, stacklevel=_stacklevel + 1
1284            )
1285            using_optim_input = FullyShardedDataParallel._is_using_optim_input(
1286                optim_input,
1287                optim,
1288            )
1289        else:
1290            using_optim_input = False
1291            assert optim_input is None and not rank0_only
1292
1293        use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
1294            0
1295        ]._use_orig_params
1296        assert all(
1297            use_orig_params == m._use_orig_params
1298            for m in FullyShardedDataParallel.fsdp_modules(model)
1299        ), "Not all FSDP modules have the same _use_orig_params value"
1300
1301        return _optim_state_dict(
1302            model=model,
1303            optim=optim,
1304            optim_state_dict=optim_state_dict,
1305            optim_input=optim_input,
1306            rank0_only=rank0_only,
1307            shard_state=not full_state_dict,
1308            group=group,
1309            using_optim_input=using_optim_input,
1310            use_orig_params=use_orig_params,
1311            cpu_offload=cpu_offload,
1312        )
1313
1314    @staticmethod
1315    def _optim_state_dict_to_load_impl(
1316        optim_state_dict: Dict[str, Any],
1317        model: torch.nn.Module,
1318        optim_input: Optional[
1319            Union[
1320                List[Dict[str, Any]],
1321                Iterable[torch.nn.Parameter],
1322            ]
1323        ] = None,
1324        optim: Optional[torch.optim.Optimizer] = None,
1325        full_state_dict: bool = True,
1326        rank0_only: bool = False,
1327        is_named_optimizer: bool = False,
1328        group: Optional[dist.ProcessGroup] = None,
1329    ) -> Dict[str, Any]:
1330        """
1331        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
1332
1333        This is the internal API that is used by all the load optim_state_dict implementations.
1334        Given model, optim, and the saved optim_state_dict, this API adds the FSDP
1335        internal information and internal sharding to the optim_state_dict.
1336        """
1337        if full_state_dict:
1338            FullyShardedDataParallel._warn_optim_input(optim_input)
1339            using_optim_input = FullyShardedDataParallel._is_using_optim_input(
1340                optim_input,
1341                optim,
1342            )
1343        else:
1344            using_optim_input = False
1345            assert optim_input is None and not rank0_only
1346
1347        use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
1348            0
1349        ]._use_orig_params
1350        assert all(
1351            use_orig_params == m._use_orig_params
1352            for m in FullyShardedDataParallel.fsdp_modules(model)
1353        ), "Not all FSDP modules have the same _use_orig_params value"
1354
1355        if rank0_only and dist.get_rank(group) > 0:
1356            optim_state_dict = {}
1357        sharded_osd = _flatten_optim_state_dict(
1358            optim_state_dict,
1359            model=model,
1360            use_orig_params=use_orig_params,
1361            optim=(optim if is_named_optimizer else None),
1362            rank0_only=rank0_only,
1363            group=group,
1364        )
1365        return _rekey_sharded_optim_state_dict(
1366            sharded_osd,
1367            model=model,
1368            optim=optim,
1369            optim_input=optim_input,
1370            using_optim_input=using_optim_input,
1371            is_named_optimizer=is_named_optimizer,
1372        )
1373
1374    @staticmethod
1375    def full_optim_state_dict(
1376        model: torch.nn.Module,
1377        optim: torch.optim.Optimizer,
1378        optim_input: Optional[
1379            Union[
1380                List[Dict[str, Any]],
1381                Iterable[torch.nn.Parameter],
1382            ]
1383        ] = None,
1384        rank0_only: bool = True,
1385        group: Optional[dist.ProcessGroup] = None,
1386    ) -> Dict[str, Any]:
1387        """Return the full optimizer state-dict.
1388
1389        Consolidates the full optimizer state on rank 0 and returns it
1390        as a :class:`dict` following the convention of
1391        :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
1392        and ``"param_groups"``. The flattened parameters in ``FSDP`` modules
1393        contained in ``model`` are mapped back to their unflattened parameters.
1394
1395        This needs to be called on all ranks since it uses
1396        collective communications. However, if ``rank0_only=True``, then
1397        the state dict is only populated on rank 0, and all other ranks
1398        return an empty :class:`dict`.
1399
1400        Unlike ``torch.optim.Optimizer.state_dict()``, this method
1401        uses full parameter names as keys instead of parameter IDs.
1402
1403        Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors
1404        contained in the optimizer state dict are not cloned, so there may
1405        be aliasing surprises. As a best practice, consider saving the
1406        returned optimizer state dict immediately, e.g. using
1407        ``torch.save()``.
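
        A minimal usage sketch, assuming ``model`` is the root FSDP-wrapped
        module and ``optim`` and ``PATH`` are defined by the caller::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
            >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # call on all ranks
            >>> if torch.distributed.get_rank() == 0:
            >>>     torch.save(full_osd, PATH)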
1408
1409        Args:
1410            model (torch.nn.Module): Root module (which may or may not be a
1411                :class:`FullyShardedDataParallel` instance) whose parameters
1412                were passed into the optimizer ``optim``.
1413            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
1414                parameters.
1415            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
1416                Input passed into the optimizer ``optim`` representing either a
1417                :class:`list` of parameter groups or an iterable of parameters;
1418                if ``None``, then this method assumes the input was
1419                ``model.parameters()``. This argument is deprecated, and there
1420                is no need to pass it in anymore. (Default: ``None``)
1421            rank0_only (bool): If ``True``, saves the populated :class:`dict`
1422                only on rank 0; if ``False``, saves it on all ranks. (Default:
1423                ``True``)
1424            group (dist.ProcessGroup): Model's process group or ``None`` if using
1425                the default process group. (Default: ``None``)
1426
1427        Returns:
1428            Dict[str, Any]: A :class:`dict` containing the optimizer state for
1429            ``model`` 's original unflattened parameters and including keys
1430            "state" and "param_groups" following the convention of
1431            :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
1432            then nonzero ranks return an empty :class:`dict`.
1433        """
1434        FullyShardedDataParallel._warn_legacy_optim_state_dict(
1435            "full_optim_state_dict",
1436            "optim_state_dict",
1437            stacklevel=2,
1438        )
1439        return FullyShardedDataParallel._optim_state_dict_impl(
1440            model=model,
1441            optim=optim,
1442            optim_state_dict=optim.state_dict(),
1443            optim_input=optim_input,
1444            rank0_only=rank0_only,
1445            group=group,
1446            full_state_dict=True,
1447            _stacklevel=2,
1448        )
1449
1450    @staticmethod
1451    def sharded_optim_state_dict(
1452        model: torch.nn.Module,
1453        optim: torch.optim.Optimizer,
1454        group: Optional[dist.ProcessGroup] = None,
1455    ) -> Dict[str, Any]:
1456        """Return the optimizer state-dict in its sharded form.
1457
1458        The API is similar to :meth:`full_optim_state_dict` but this API chunks
1459        all non-zero-dimension states into :class:`ShardedTensor` to save memory.
1460        This API should only be used when the model ``state_dict`` is derived
1461        under the context manager ``with state_dict_type(SHARDED_STATE_DICT):``.
1462
1463        For the detailed usage, refer to :meth:`full_optim_state_dict`.
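
        A minimal usage sketch, assuming ``model`` is the root FSDP instance
        and ``optim`` is its optimizer::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
            >>> from torch.distributed.fsdp import StateDictType
            >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            >>>     sharded_osd = FSDP.sharded_optim_state_dict(model, optim)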
1464
1465        .. warning:: The returned state dict contains ``ShardedTensor`` and
1466            cannot be directly used by the regular ``optim.load_state_dict``.
1467        """
1468        FullyShardedDataParallel._warn_legacy_optim_state_dict(
1469            "sharded_optim_state_dict",
1470            "optim_state_dict",
1471            stacklevel=2,
1472        )
1473        return FullyShardedDataParallel._optim_state_dict_impl(
1474            model=model,
1475            optim=optim,
1476            optim_state_dict=optim.state_dict(),
1477            optim_input=None,
1478            rank0_only=False,
1479            full_state_dict=False,
1480            group=group,
1481            _stacklevel=2,
1482        )
1483
1484    @staticmethod
1485    def shard_full_optim_state_dict(
1486        full_optim_state_dict: Dict[str, Any],
1487        model: torch.nn.Module,
1488        optim_input: Optional[
1489            Union[
1490                List[Dict[str, Any]],
1491                Iterable[torch.nn.Parameter],
1492            ]
1493        ] = None,
1494        optim: Optional[torch.optim.Optimizer] = None,
1495    ) -> Dict[str, Any]:
1496        """Shard a full optimizer state-dict.
1497
1498        Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
1499        parameters and restricts to only this rank's part of the optimizer state.
1500        The first argument should be the return value of :meth:`full_optim_state_dict`.
1501
1502        Example::
1503
1504            >>> # xdoctest: +SKIP("undefined variables")
1505            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1506            >>> model, optim = ...
1507            >>> full_osd = FSDP.full_optim_state_dict(model, optim)
1508            >>> torch.save(full_osd, PATH)
1509            >>> # Define new model with possibly different world size
1510            >>> new_model, new_optim = ...
1511            >>> full_osd = torch.load(PATH)
1512            >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
1513            >>> new_optim.load_state_dict(sharded_osd)
1514
1515        .. note:: Both :meth:`shard_full_optim_state_dict` and
1516            :meth:`scatter_full_optim_state_dict` may be used to get the
1517            sharded optimizer state dict to load. Assuming that the full
1518            optimizer state dict resides in CPU memory, the former requires
1519            each rank to have the full dict in CPU memory, where each rank
1520            individually shards the dict without any communication, while the
1521            latter requires only rank 0 to have the full dict in CPU memory,
1522            where rank 0 moves each shard to GPU memory (for NCCL) and
1523            communicates it to ranks appropriately. Hence, the former has
1524            higher aggregate CPU memory cost, while the latter has higher
1525            communication cost.
1526
1527        Args:
1528            full_optim_state_dict (Dict[str, Any]): Optimizer state dict
1529                corresponding to the unflattened parameters and holding the
1530                full non-sharded optimizer state.
1531            model (torch.nn.Module): Root module (which may or may not be a
1532                :class:`FullyShardedDataParallel` instance) whose parameters
1533                correspond to the optimizer state in ``full_optim_state_dict``.
1534            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
1535                Input passed into the optimizer representing either a
1536                :class:`list` of parameter groups or an iterable of parameters;
1537                if ``None``, then this method assumes the input was
1538                ``model.parameters()``. This argument is deprecated, and there
1539                is no need to pass it in anymore. (Default: ``None``)
1540            optim (Optional[torch.optim.Optimizer]): Optimizer that will load
1541                the state dict returned by this method. This is the preferred
1542                argument to use over ``optim_input``. (Default: ``None``)
1543
1544        Returns:
1545            Dict[str, Any]: The full optimizer state dict now remapped to
1546            flattened parameters instead of unflattened parameters and
1547            restricted to only include this rank's part of the optimizer state.
1548        """
1549        FullyShardedDataParallel._warn_legacy_optim_state_dict(
1550            "shard_full_optim_state_dict",
1551            "optim_state_dict_to_load",
1552            stacklevel=2,
1553        )
1554        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
1555            optim_state_dict=full_optim_state_dict,
1556            model=model,
1557            optim_input=optim_input,
1558            optim=optim,
1559            full_state_dict=True,
1560            is_named_optimizer=False,
1561        )
1562
1563    @staticmethod
1564    def flatten_sharded_optim_state_dict(
1565        sharded_optim_state_dict: Dict[str, Any],
1566        model: torch.nn.Module,
1567        optim: torch.optim.Optimizer,
1568    ) -> Dict[str, Any]:
1569        """Flatten a sharded optimizer state-dict.
1570
1571        The API is similar to :meth:`shard_full_optim_state_dict`. The only
1572        difference is that the input ``sharded_optim_state_dict`` should be
1573        returned from :meth:`sharded_optim_state_dict`. Therefore, there will
1574        be all-gather calls on each rank to gather ``ShardedTensor`` s.
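
        A minimal round-trip sketch, assuming ``model`` is the root FSDP
        instance and ``optim`` is its optimizer; checkpoint save/load of the
        sharded state dict is elided::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
            >>> sharded_osd = FSDP.sharded_optim_state_dict(model, optim)
            >>> # ... save and reload `sharded_osd` here ...
            >>> flattened_osd = FSDP.flatten_sharded_optim_state_dict(
            >>>     sharded_osd, model, optim
            >>> )
            >>> optim.load_state_dict(flattened_osd)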
1575
1576        Args:
1577            sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict
1578                corresponding to the unflattened parameters and holding the
1579                sharded optimizer state.
1580            model (torch.nn.Module):
1581                Refer to :meth:`shard_full_optim_state_dict`.
1582            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
1583                parameters.
1584
1585        Returns:
1586            Refer to :meth:`shard_full_optim_state_dict`.
1587        """
1588        FullyShardedDataParallel._warn_legacy_optim_state_dict(
1589            "flatten_sharded_optim_state_dict",
1590            "optim_state_dict_to_load",
1591            stacklevel=2,
1592        )
1593        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
1594            optim_state_dict=sharded_optim_state_dict,
1595            model=model,
1596            optim_input=None,
1597            optim=optim,
1598            full_state_dict=False,
1599            is_named_optimizer=False,
1600        )
1601
1602    @staticmethod
1603    def scatter_full_optim_state_dict(
1604        full_optim_state_dict: Optional[Dict[str, Any]],
1605        model: torch.nn.Module,
1606        optim_input: Optional[
1607            Union[
1608                List[Dict[str, Any]],
1609                Iterable[torch.nn.Parameter],
1610            ]
1611        ] = None,
1612        optim: Optional[torch.optim.Optimizer] = None,
1613        group: Optional[Any] = None,
1614    ) -> Dict[str, Any]:
1615        """Scatter the full optimizer state dict from rank 0 to all other ranks.
1616
1617        Returns the sharded optimizer state dict on each rank.
1618        The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank
1619        0, the first argument should be the return value of
1620        :meth:`full_optim_state_dict`.
1621
1622        Example::
1623
1624            >>> # xdoctest: +SKIP("undefined variables")
1625            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1626            >>> model, optim = ...
1627            >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
1628            >>> # Define new model with possibly different world size
1629            >>> new_model, new_optim, new_group = ...
1630            >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
1631            >>> new_optim.load_state_dict(sharded_osd)
1632
1633        .. note:: Both :meth:`shard_full_optim_state_dict` and
1634            :meth:`scatter_full_optim_state_dict` may be used to get the
1635            sharded optimizer state dict to load. Assuming that the full
1636            optimizer state dict resides in CPU memory, the former requires
1637            each rank to have the full dict in CPU memory, where each rank
1638            individually shards the dict without any communication, while the
1639            latter requires only rank 0 to have the full dict in CPU memory,
1640            where rank 0 moves each shard to GPU memory (for NCCL) and
1641            communicates it to ranks appropriately. Hence, the former has
1642            higher aggregate CPU memory cost, while the latter has higher
1643            communication cost.
1644
1645        Args:
1646            full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
1647                dict corresponding to the unflattened parameters and holding
1648                the full non-sharded optimizer state if on rank 0; the argument
1649                is ignored on nonzero ranks.
1650            model (torch.nn.Module): Root module (which may or may not be a
1651                :class:`FullyShardedDataParallel` instance) whose parameters
1652                correspond to the optimizer state in ``full_optim_state_dict``.
1653            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
1654                Input passed into the optimizer representing either a
1655                :class:`list` of parameter groups or an iterable of parameters;
1656                if ``None``, then this method assumes the input was
1657                ``model.parameters()``. This argument is deprecated, and there
1658                is no need to pass it in anymore. (Default: ``None``)
1659            optim (Optional[torch.optim.Optimizer]): Optimizer that will load
1660                the state dict returned by this method. This is the preferred
1661                argument to use over ``optim_input``. (Default: ``None``)
1662            group (dist.ProcessGroup): Model's process group or ``None`` if
1663                using the default process group. (Default: ``None``)
1664
1665        Returns:
1666            Dict[str, Any]: The full optimizer state dict now remapped to
1667            flattened parameters instead of unflattened parameters and
1668            restricted to only include this rank's part of the optimizer state.
1669        """
1670        FullyShardedDataParallel._warn_legacy_optim_state_dict(
1671            "scatter_full_optim_state_dict",
1672            "optim_state_dict_to_load",
1673            stacklevel=2,
1674        )
1675        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
1676            optim_state_dict=full_optim_state_dict,
1677            model=model,
1678            optim_input=optim_input,
1679            optim=optim,
1680            full_state_dict=True,
1681            rank0_only=True,
1682            is_named_optimizer=False,
1683            group=group,
1684        )
1685
1686    @staticmethod
1687    def rekey_optim_state_dict(
1688        optim_state_dict: Dict[str, Any],
1689        optim_state_key_type: OptimStateKeyType,
1690        model: torch.nn.Module,
1691        optim_input: Optional[
1692            Union[
1693                List[Dict[str, Any]],
1694                Iterable[torch.nn.Parameter],
1695            ]
1696        ] = None,
1697        optim: Optional[torch.optim.Optimizer] = None,
1698    ) -> Dict[str, Any]:
1699        """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.
1700
1701        This can be used to achieve compatibility between optimizer state dicts from models with FSDP
1702        instances and ones without.
1703
1704        To re-key an FSDP full optimizer state dict (i.e. from
1705        :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to
1706        a non-wrapped model::
1707
1708            >>> # xdoctest: +SKIP("undefined variables")
1709            >>> wrapped_model, wrapped_optim = ...
1710            >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
1711            >>> nonwrapped_model, nonwrapped_optim = ...
1712            >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
1713            >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
1714
1715        To re-key a normal optimizer state dict from a non-wrapped model to be
1716        loadable to a wrapped model::
1717
1718            >>> # xdoctest: +SKIP("undefined variables")
1719            >>> nonwrapped_model, nonwrapped_optim = ...
1720            >>> osd = nonwrapped_optim.state_dict()
1721            >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
1722            >>> wrapped_model, wrapped_optim = ...
1723            >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
1724            >>> wrapped_optim.load_state_dict(sharded_osd)
1725
1726        Returns:
1727            Dict[str, Any]: The optimizer state dict re-keyed using the
1728            parameter keys specified by ``optim_state_key_type``.
1729        """
1730        FullyShardedDataParallel._warn_optim_input(optim_input)
1731        using_optim_input = FullyShardedDataParallel._is_using_optim_input(
1732            optim_input,
1733            optim,
1734        )
1735        assert optim_state_key_type in (
1736            OptimStateKeyType.PARAM_NAME,
1737            OptimStateKeyType.PARAM_ID,
1738        )
1739        osd = optim_state_dict  # alias
1740        # Validate that the existing parameter keys are uniformly typed
1741        uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
1742        uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
1743        if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
1744            any(uses_param_id_mask) and not all(uses_param_id_mask)
1745        ):
1746            error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
1747            raise ValueError(error_msg)
1748        # Return directly if the existing key type matches the target key type
1749        if (
1750            optim_state_key_type == OptimStateKeyType.PARAM_NAME
1751            and all(uses_param_name_mask)
1752        ) or (
1753            optim_state_key_type == OptimStateKeyType.PARAM_ID
1754            and all(uses_param_id_mask)
1755        ):
1756            return osd
1757        # Otherwise, actually perform the re-keying
1758        new_osd = {}
1759        if optim_state_key_type == OptimStateKeyType.PARAM_NAME:  # ID -> name
1760            param_id_to_param = (
1761                _get_param_id_to_param_from_optim_input(model, optim_input)
1762                if using_optim_input
1763                else _get_param_key_to_param(optim)
1764            )
1765            param_to_param_name = _get_param_to_fqn(model)
1766            param_id_to_param_name: List[str] = [
1767                param_to_param_name[param] for param in param_id_to_param.values()
1768            ]
1769            new_osd["state"] = {
1770                param_id_to_param_name[param_id]: param_state
1771                for param_id, param_state in osd["state"].items()
1772            }
1773            new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
1774            for param_group in new_osd["param_groups"]:
1775                param_group["params"] = sorted(
1776                    [
1777                        param_id_to_param_name[param_id]
1778                        for param_id in param_group["params"]
1779                    ]
1780                )
1781            return new_osd
1782        elif optim_state_key_type == OptimStateKeyType.PARAM_ID:  # name -> ID
1783            param_name_to_param = _get_fqn_to_param(model)
1784            param_to_param_id = (
1785                _get_param_to_param_id_from_optim_input(model, optim_input)
1786                if using_optim_input
1787                else _get_param_to_param_key(optim)
1788            )
1789            # Because not all model parameters may be passed as the optimizer
1790            # input, we may need to drop some parameters from this mapping
1791            param_name_to_param_id = {
1792                param_name: param_to_param_id[param]
1793                for param_name, param in param_name_to_param.items()
1794                if param in param_to_param_id
1795            }
1796            new_osd["state"] = {
1797                param_name_to_param_id[param_name]: param_state
1798                for param_name, param_state in osd["state"].items()
1799            }
1800            new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
1801            for param_group in new_osd["param_groups"]:
1802                param_group["params"] = sorted(
1803                    [
1804                        param_name_to_param_id[param_name]
1805                        for param_name in param_group["params"]
1806                    ]
1807                )
1808            return new_osd
1809        return new_osd  # should never reach here
1810
1811    @staticmethod
1812    def optim_state_dict(
1813        model: torch.nn.Module,
1814        optim: torch.optim.Optimizer,
1815        optim_state_dict: Optional[Dict[str, Any]] = None,
1816        group: Optional[dist.ProcessGroup] = None,
1817    ) -> Dict[str, Any]:
1818        """
1819        Transform the state-dict of an optimizer corresponding to a sharded model.
1820
1821        The given state-dict can be transformed to one of three types:
1822        1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.
1823
1824        For the full optimizer state_dict, all states are unflattened and not sharded.
1825        Rank0-only and CPU-only saving can be specified via :meth:`state_dict_type` to
1826        avoid OOM.
1827
1828        For the sharded optimizer state_dict, all states are unflattened but sharded.
1829        CPU-only saving can be specified via :meth:`state_dict_type` to further save
1830        memory.
1831
1832        For the local state_dict, no transformation will be performed, but a state
1833        will be converted from a ``torch.Tensor`` to a ``ShardedTensor`` to represent
1834        its sharded nature (this is not supported yet).
1835
1836        Example::
1837
1838            >>> # xdoctest: +SKIP("undefined variables")
1839            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1840            >>> from torch.distributed.fsdp import StateDictType
1841            >>> from torch.distributed.fsdp import FullStateDictConfig
1842            >>> from torch.distributed.fsdp import FullOptimStateDictConfig
1843            >>> # Save a checkpoint
1844            >>> model, optim = ...
1845            >>> FSDP.set_state_dict_type(
1846            >>>     model,
1847            >>>     StateDictType.FULL_STATE_DICT,
1848            >>>     FullStateDictConfig(rank0_only=False),
1849            >>>     FullOptimStateDictConfig(rank0_only=False),
1850            >>> )
1851            >>> state_dict = model.state_dict()
1852            >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
1853            >>> save_a_checkpoint(state_dict, optim_state_dict)
1854            >>> # Load a checkpoint
1855            >>> model, optim = ...
1856            >>> state_dict, optim_state_dict = load_a_checkpoint()
1857            >>> FSDP.set_state_dict_type(
1858            >>>     model,
1859            >>>     StateDictType.FULL_STATE_DICT,
1860            >>>     FullStateDictConfig(rank0_only=False),
1861            >>>     FullOptimStateDictConfig(rank0_only=False),
1862            >>> )
1863            >>> model.load_state_dict(state_dict)
1864            >>> optim_state_dict = FSDP.optim_state_dict_to_load(
1865            >>>     model, optim, optim_state_dict
1866            >>> )
1867            >>> optim.load_state_dict(optim_state_dict)
1868
1869        Args:
1870            model (torch.nn.Module): Root module (which may or may not be a
1871                :class:`FullyShardedDataParallel` instance) whose parameters
1872                were passed into the optimizer ``optim``.
1873            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
1874                parameters.
1875            optim_state_dict (Dict[str, Any]): the target optimizer state_dict to
1876                transform. If the value is None, optim.state_dict() will be used. (
1877                Default: ``None``)
1878            group (dist.ProcessGroup): Model's process group across which parameters
1879                are sharded or ``None`` if using the default process group. (
1880                Default: ``None``)
1881
1882        Returns:
1883            Dict[str, Any]: A :class:`dict` containing the optimizer state for
1884            ``model``. The sharding of the optimizer state is based on
1885            ``state_dict_type``.
1886        """
1887        state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
1888        if optim_state_dict is None:
1889            optim_state_dict = optim.state_dict()
1890        return FullyShardedDataParallel._optim_state_dict_impl(
1891            model=model,
1892            optim=optim,
1893            optim_state_dict=optim_state_dict,
1894            optim_input=None,
1895            rank0_only=getattr(
1896                state_dict_settings.optim_state_dict_config, "rank0_only", False
1897            ),
1898            full_state_dict=state_dict_settings.state_dict_type
1899            == StateDictType.FULL_STATE_DICT,
1900            group=group,
1901            cpu_offload=getattr(
1902                state_dict_settings.optim_state_dict_config, "offload_to_cpu", True
1903            ),
1904            _stacklevel=2,
1905        )
1906
1907    @staticmethod
1908    def optim_state_dict_to_load(
1909        model: torch.nn.Module,
1910        optim: torch.optim.Optimizer,
1911        optim_state_dict: Dict[str, Any],
1912        is_named_optimizer: bool = False,
1913        load_directly: bool = False,
1914        group: Optional[dist.ProcessGroup] = None,
1915    ) -> Dict[str, Any]:
1916        """
1917        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.
1918
1919        Given an ``optim_state_dict`` that was transformed through
1920        :meth:`optim_state_dict`, this converts it to the flattened optimizer
1921        state_dict that can be loaded into ``optim``, the optimizer for
1922        ``model``. ``model`` must be sharded by FullyShardedDataParallel.
1923
1924            >>> # xdoctest: +SKIP("undefined variables")
1925            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1926            >>> from torch.distributed.fsdp import StateDictType
1927            >>> from torch.distributed.fsdp import FullStateDictConfig
1928            >>> from torch.distributed.fsdp import FullOptimStateDictConfig
1929            >>> # Save a checkpoint
1930            >>> model, optim = ...
1931            >>> FSDP.set_state_dict_type(
1932            >>>     model,
1933            >>>     StateDictType.FULL_STATE_DICT,
1934            >>>     FullStateDictConfig(rank0_only=False),
1935            >>>     FullOptimStateDictConfig(rank0_only=False),
1936            >>> )
1937            >>> state_dict = model.state_dict()
1938            >>> original_osd = optim.state_dict()
1939            >>> optim_state_dict = FSDP.optim_state_dict(
1940            >>>     model,
1941            >>>     optim,
1942            >>>     optim_state_dict=original_osd
1943            >>> )
1944            >>> save_a_checkpoint(state_dict, optim_state_dict)
1945            >>> # Load a checkpoint
1946            >>> model, optim = ...
1947            >>> state_dict, optim_state_dict = load_a_checkpoint()
1948            >>> FSDP.set_state_dict_type(
1949            >>>     model,
1950            >>>     StateDictType.FULL_STATE_DICT,
1951            >>>     FullStateDictConfig(rank0_only=False),
1952            >>>     FullOptimStateDictConfig(rank0_only=False),
1953            >>> )
1954            >>> model.load_state_dict(state_dict)
1955            >>> optim_state_dict = FSDP.optim_state_dict_to_load(
1956            >>>     model, optim, optim_state_dict
1957            >>> )
1958            >>> optim.load_state_dict(optim_state_dict)
1959
1960        Args:
1961            model (torch.nn.Module): Root module (which may or may not be a
1962                :class:`FullyShardedDataParallel` instance) whose parameters
1963                were passed into the optimizer ``optim``.
1964            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
1965                parameters.
1966            optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
1967            is_named_optimizer (bool): Whether this optimizer is a NamedOptimizer
1968                or KeyedOptimizer. Set to ``True`` only if ``optim`` is TorchRec's
1969                KeyedOptimizer or torch.distributed's NamedOptimizer.
1970            load_directly (bool): If ``True``, this API also calls
1971                ``optim.load_state_dict(result)`` before returning the result.
1972                Otherwise, users are responsible for calling
1973                ``optim.load_state_dict()``. (Default: ``False``)
1974            group (dist.ProcessGroup): Model's process group across which parameters
1975                are sharded or ``None`` if using the default process group. (
1976                Default: ``None``)
1977        """
1978        state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
1979        result = FullyShardedDataParallel._optim_state_dict_to_load_impl(
1980            optim_state_dict=optim_state_dict,
1981            model=model,
1982            optim_input=None,
1983            optim=optim,
1984            full_state_dict=(
1985                state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT
1986            ),
1987            rank0_only=getattr(
1988                state_dict_settings.optim_state_dict_config, "rank0_only", False
1989            ),
1990            is_named_optimizer=is_named_optimizer,
1991            group=group,
1992        )
1993        if load_directly:
1994            optim.load_state_dict(result)
1995        return result
1996
1997    def register_comm_hook(self, state: object, hook: callable):
1998        """Register a communication hook.
1999
2000        This gives users a flexible hook for specifying how FSDP aggregates
2001        gradients across multiple workers.
2002        This hook can be used to implement several algorithms like
2003        `GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression,
2004        which involve different communication strategies for
2005        parameter syncs while training with :class:`FullyShardedDataParallel`.
2006
2007        .. warning::
2008            The FSDP communication hook should be registered only once and
2009            before running the initial forward pass.
2010
2011        Args:
2012            state (object): Passed to the hook to maintain any state information during the training process.
2013                            Examples include error feedback in gradient compression,
2014                            peers to communicate with next in `GossipGrad <https://arxiv.org/abs/1803.05880>`_, etc.
2015                            It is locally stored by each worker
2016                            and shared by all the gradient tensors on the worker.
2017            hook (Callable): Callable, which has one of the following signatures:
2018                            1) ``hook: Callable[torch.Tensor] -> None``:
2019                            This function takes in a Python tensor, which represents
2020                            the full, flattened, unsharded gradient with respect to all variables
2021                            corresponding to the model this FSDP unit is wrapping
2022                            (that are not wrapped by other FSDP sub-units).
2023                            It then performs all necessary processing and returns ``None``;
2024                            2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
2025                            This function takes in two Python tensors, the first one represents
2026                            the full, flattened, unsharded gradient with respect to all variables
2027                            corresponding to the model this FSDP unit is wrapping
2028                            (that are not wrapped by other FSDP sub-units). The latter
2029                            represents a pre-sized tensor to store a chunk of a sharded gradient after
2030                            reduction.
2031                            In both cases, the callable performs all necessary processing and returns ``None``.
2032                            Callables with signature 1 are expected to handle gradient communication for the ``NO_SHARD`` case.
2033                            Callables with signature 2 are expected to handle gradient communication for sharded cases.
2034
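        A minimal sketch of a user-defined hook for the ``NO_SHARD`` case
        (signature 1) follows; ``fsdp_model`` is assumed to be the root FSDP
        instance, and passing the process group itself as the hook ``state``
        is an illustrative choice rather than a requirement of this API::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> import torch.distributed as dist
            >>> def allreduce_hook(state, grad):
            >>>     # `state` is the process group registered below; all-reduce
            >>>     # the full, unsharded gradient in place and average it
            >>>     dist.all_reduce(grad, group=state)
            >>>     grad.div_(dist.get_world_size(state))
            >>> fsdp_model.register_comm_hook(dist.group.WORLD, allreduce_hook)
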
2035        """
2036        if not self.check_is_root():
2037            raise AssertionError(
2038                "register_comm_hook can only be called on a root instance."
2039            )
2040        for fsdp_state in traversal_utils._get_fsdp_states(self):
2041            if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
2042                raise AssertionError(
2043                    f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
2044                )
2045            if fsdp_state._comm_hook is not None:
2046                raise AssertionError("A communication hook is already registered")
2047            if not callable(hook):
2048                raise ValueError(
2049                    f"The communication hook must be callable but got {hook}"
2050                )
2051            fsdp_state._comm_hook = hook
2052            fsdp_state._comm_hook_state = state
2053
2054    def _unshard(self, async_op: bool = False):
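        # Unshard this FSDP instance's flat parameter on the unshard stream.
        # If ``async_op=True``, return a handle whose ``wait()`` makes the
        # current stream wait on the recorded unshard event; otherwise, wait
        # before returning.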
2055        class UnshardHandle:
2056            def __init__(
2057                self,
2058                flat_param_handle: Optional[FlatParamHandle],
2059                unshard_event: torch.Event,
2060            ):
2061                self._flat_param_handle = flat_param_handle
2062                self._unshard_event = unshard_event
2063
2064            def wait(self):
2065                if self._flat_param_handle is not None:
2066                    current_stream = (
2067                        self._flat_param_handle._device_handle.current_stream()
2068                    )
2069                    current_stream.wait_event(self._unshard_event)
2070                    self._flat_param_handle = None
2071
2072        if self._handle:
2073            with self._use_training_state(
2074                TrainingState.FORWARD_BACKWARD, HandleTrainingState.FORWARD
2075            ):
2076                _unshard(
2077                    self, self._handle, self._unshard_stream, self._pre_unshard_stream
2078                )
2079                self._unshard_event = self._unshard_stream.record_event()
2080            self._handle._prefetched = True
2081        unshard_handle = UnshardHandle(self._handle, self._unshard_stream)
2082        if async_op:
2083            return unshard_handle
2084        unshard_handle.wait()
2085        return None
2086
2087    def _wait_unshard_streams_on_current_stream(self):
2088        _wait_for_computation_stream(
2089            self._device_handle.current_stream(),
2090            self._unshard_stream,
2091            self._pre_unshard_stream,
2092        )
2093
2094    @contextlib.contextmanager
2095    def _use_training_state(
2096        self, training_state: TrainingState, handle_training_state: HandleTrainingState
2097    ):
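        # Temporarily set this FSDP state's `training_state` (and the handle's
        # `_training_state`, if a handle exists) for the duration of the
        # context, restoring the previous values on exit.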
2098        prev_training_state = self.training_state
2099        self.training_state = training_state
2100        if self._handle:
2101            prev_handle_training_state = self._handle._training_state
2102            self._handle._training_state = handle_training_state
2103        try:
2104            yield
2105        finally:
2106            self.training_state = prev_training_state
2107            if self._handle:
2108                self._handle._training_state = prev_handle_training_state
2109
2110
2111def _get_grad_norm(
2112    params: Iterable[nn.Parameter],
2113    norm_type: float,
2114    zero: torch.Tensor,
2115    device: torch.device,
2116) -> torch.Tensor:
2117    """
2118    Return the gradient norm of the given ``params``, where the gradients are viewed as a single vector.
2119
2120    The returned norm is in FP32 even if parameters/gradients are in low precision. This is because the downstream
2121    use of this return value is a reduction across ranks.
2122    """
2123    params_with_grad = [param for param in params if param.grad is not None]
2124    if len(params_with_grad) == 0:
2125        # Reuse a tensor for zero to avoid a GPU sync
2126        return zero
2127    grads = [param.grad for param in params_with_grad]
2128    grad_dtypes = {grad.dtype for grad in grads}
2129    if len(grad_dtypes) != 1:
2130        raise ValueError(
2131            f"Requires uniform dtype across all gradients but got {grad_dtypes}"
2132        )
2133    # Compute the gradient norm in FP32, where we treat the gradients as a
2134    # single vector
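    # (for a p-norm, the norm of the per-gradient norms equals the norm of
    # the concatenated gradient vector)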
2135    grad_norm = torch.linalg.vector_norm(
2136        torch.stack(
2137            [
2138                torch.linalg.vector_norm(grad.detach(), norm_type, dtype=torch.float32)
2139                for grad in grads
2140            ],
2141        ),
2142        norm_type,
2143        dtype=torch.float32,
2144    )
2145    return grad_norm.to(device=device)
2146
2147
2148def _get_param_to_fqn(
2149    model: torch.nn.Module,
2150) -> Dict[torch.nn.Parameter, str]:
2151    """
2152    Construct a mapping from parameters to their parameter names.
2153
2154    The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which
2155    means that none of the parameters should be ``FlatParameter`` s. As a
2156    result, compared to :meth:`_get_param_to_fqns`, the mapped
2157    values may be flattened from singleton :class:`list` s to the contained
2158    names themselves.
2159
2160    Args:
2161        model (torch.nn.Module): Root module, which should not contain any
2162            :class:`FullyShardedDataParallel` instances.
2163    """
2164    param_to_param_names = _get_param_to_fqns(model)
2165    for param_names in param_to_param_names.values():
2166        assert (
2167            len(param_names) > 0
2168        ), "`_get_param_to_fqns()` should not construct empty lists"
2169        if len(param_names) > 1:
2170            raise RuntimeError(
2171                "Each parameter should only map to one parameter name but got "
2172                f"{len(param_names)}: {param_names}"
2173            )
2174    param_to_param_name = {
2175        param: param_names[0] for param, param_names in param_to_param_names.items()
2176    }
2177    return param_to_param_name
2178
2179
2180def _get_fqn_to_param(
2181    model: torch.nn.Module,
2182) -> Dict[str, torch.nn.Parameter]:
2183    """Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
2184    param_to_param_name = _get_param_to_fqn(model)
2185    return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
2186