1# mypy: allow-untyped-defs
2import contextlib
3import functools
4import logging
5import os
6import warnings
7from enum import auto, Enum
8from itertools import accumulate, chain
9from typing import (
10    Any,
11    Callable,
12    cast,
13    Dict,
14    Generator,
15    Iterator,
16    List,
17    NamedTuple,
18    no_type_check,
19    Optional,
20    Sequence,
21    Set,
22    Tuple,
23    Union,
24)
25
26import torch
27import torch.distributed as dist
28import torch.nn as nn
29import torch.nn.functional as F
30from torch import Tensor
31from torch.distributed.fsdp._common_utils import (
32    _FSDPDeviceHandle,
33    _named_parameters_with_duplicates,
34    _no_dispatch_record_stream,
35    _set_fsdp_flattened,
36    HandleTrainingState,
37)
38from torch.distributed.utils import (
39    _alloc_storage,
40    _data_ptr_allocated,
41    _free_storage,
42    _p_assert,
43)
44from torch.nn.parameter import _ParameterMeta  # type: ignore[attr-defined]
45from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
46
47from ._fsdp_extensions import (
48    _ext_post_unflatten_transform,
49    _ext_pre_flatten_transform,
50    FSDPExtensions,
51)
52
53
54__all__ = [
55    "FlatParameter",
56    "FlatParamHandle",
57    "FlatParamShardMetadata",
58    "ParamInfo",
59    "SharedParamInfo",
60    "HandleShardingStrategy",
61]
62
63logger = logging.getLogger(__name__)
64
65
66"""
67[Note: Fully Sharded Module]
68We define the "fully sharded module" to be the original ``nn.Module`` that owns
69a ``FlatParamHandle``. It is the *single* module logically responsible for the
70*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
71forward or backward pass. The fully sharded module should be passed to the
72``FlatParamHandle`` constructor.
73
74For the wrapper code path:
75- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
76runs the unshard/reshard on behalf of the fully sharded module by overriding
77``nn.Module.forward``.
78- The fully sharded module is exactly the module passed to the
79``FullyShardedDataParallel`` constructor's ``module`` argument.
80
81For the non-wrapper code path:
82- Hooks registered on the fully sharded module run the unshard/reshard.
83- The fully sharded module may either be the direct argument to ``fully_shard``
84or a submodule chosen by the provided wrapping policy.
85"""
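# For example (a minimal wrapper-path sketch; `my_block` is a hypothetical
# submodule, not part of this file):
#   fsdp_block = FullyShardedDataParallel(my_block)
# Here, `my_block` is the fully sharded module that owns the handle, and the
# overridden ``forward`` of `fsdp_block` runs the unshard/reshard on its behalf.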
86
87# Environment variable toggling whether to use unsafe `setattr()` for view
88# setting in `_use_sharded_views()` and `_use_unsharded_views()`
89# We should use 'safe' by default since it respects method overrides, but for
90# special cases, such as when CPU overhead is a concern or when intentionally
91# bypassing checks in the overrides, we may use 'unsafe'.
92_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"
93
94# Environment variable toggling whether to check for parameter/gradient
95# writeback in case their storages change after FSDP initialization
96# We should check by default since it prevents silent correctness errors, but
97# since such changes are atypical, we may want to skip the check to save CPU
98# overhead, especially since the check happens in the pre-forward and
99# pre-backward hooks each iteration.
100_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"
101
102# Environment variable toggling whether to run in full precision (fp32) or in
103# the reduced precision when the model is in `.eval()` mode.
104_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"
105
106# Value used to fill padding in tensors, for debuggability
107_FLAT_PARAM_PADDING_VALUE = 42
108
109# Environment variables for disabling the all-gather and reduce-scatter
110# communication ops for ablation studies. Note that without these communication
111# ops the training won't converge, and you probably need to disable correctness
112# checks in your model.
113_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
114_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
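# Illustrative usage sketch: these toggles are read from the environment in
# `FlatParamHandle.__init__`, so they must be set before FSDP construction. The
# values shown are assumptions for the example, not recommendations:
#   os.environ[_FSDP_SKIP_WRITEBACK_CHECK] = "1"  # skip the writeback check
#   os.environ[_FSDP_USE_FAKE_ALL_GATHER] = "1"   # ablation: no all-gather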
115
116
117# TODO: Define this for now to avoid circular imports. See if we can remove.
118class HandleShardingStrategy(Enum):
119    FULL_SHARD = auto()
120    SHARD_GRAD_OP = auto()
121    NO_SHARD = auto()
122    HYBRID_SHARD = auto()
123    _HYBRID_SHARD_ZERO2 = auto()
124
125
126RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
127    HandleShardingStrategy.FULL_SHARD,
128    HandleShardingStrategy.HYBRID_SHARD,
129)
130NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
131    HandleShardingStrategy.SHARD_GRAD_OP,
132    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
133)
134
135
136class ParamInfo(NamedTuple):
137    """Information for an original parameter."""
138
139    param_name: str  # unprefixed
140    module: nn.Module
141    module_name: str
142
143
144class SharedParamInfo(NamedTuple):
145    """
146    Additional information for a shared parameter.
147
148    For each shared parameter, we designate one module and its parameter
149    variable to be the primary owner, determined as the first one encountered
150    in the parameter walk. These are prefixed with "prim". The primary module
151    and parameter do not have their own :class:`SharedParamInfo` instance.
152    """
153
154    param_name: str  # unprefixed
155    module: nn.Module
156    module_name: str
157    prim_param_name: str  # unprefixed
158    prim_module: nn.Module
159    prim_module_name: str
160
161
162class _ShardParamInfo(NamedTuple):
163    """Shard-related information for an original parameter."""
164
165    in_shard: bool
166    # Used to index into the sharded flat parameter, e.g.
167    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
168    offset_in_shard: Optional[int]
169    numel_in_shard: Optional[int]
170    # Used to get part of the parameter in the local shard from a flattened
171    # version of the unsharded parameter, e.g.
172    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
173    intra_param_start_idx: Optional[int]
174    intra_param_end_idx: Optional[int]  # inclusive
175
176
177class FlatParamShardMetadata(NamedTuple):
178    """
179    This holds metadata specific to this rank's shard of the flat parameter.
180
181    Attributes:
182        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
183            shard of the parameters; see :class:`FlatParameter`.
184        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
185            shard of the parameters; see :class:`FlatParameter`.
186        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
187            of the parameters; see :class:`FlatParameter`.
188        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
189            units of numels) giving this rank's part of each flattened
190            original parameter.
191    """
192
193    param_names: Tuple[str, ...]
194    param_shapes: Tuple[torch.Size, ...]
195    param_numels: Tuple[int, ...]
196    param_offsets: Tuple[Tuple[int, int], ...]
197
198
199class _FlatParameterMeta(_ParameterMeta):
200    # Make `isinstance(t, FlatParameter)` return True for custom tensor
201    # instances that have the _is_flat_param flag for BC
202    def __instancecheck__(self, instance):
203        # NB: do NOT test the super implementation
204        return isinstance(instance, torch.Tensor) and getattr(
205            instance, "_is_flat_param", False
206        )
207
208
209class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
210    """
211    This is the flat parameter used by :class:`FullyShardedDataParallel`.
212
213    It is composed of one or more original parameters, which are flattened and
214    concatenated to construct the flat parameter.
215
216    Under the current design, this parameter logically represents both the
217    unsharded and sharded flat parameter, and its data changes storages
218    dynamically.
219        - In the :class:`FullyShardedDataParallel` constructor, the parameter
220        is initialized as unsharded and then sharded in-place.
221        - At runtime, the parameter is lazily (re)-initialized. The sharded
222        parameter data is saved in ``self._local_shard``, and a new ``Tensor``
223        ``self._full_param_padded`` is created, which is the all-gather
224        destination and owns the unsharded parameter storage thereafter. (See
225        :meth:`FlatParamHandle.init_flat_param_attributes`.)
226        - Throughout runtime, the parameter data changes storages as needed,
227        e.g. to the sharded flat parameter, low precision sharded flat
228        parameter, or the unsharded flat parameter.
229
230    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
231    padding, we have two versions of the per-parameter numels, one that
232    includes the padding (``_numels_with_padding``) and one that does not
233    (``_numels``). The former may be longer than the other per-parameter data
234    structures, while the latter has length equal to the number of actual
235    original parameters, matching the other per-parameter data structures.
236
237    NOTE: This is not a real class; instead, you will always get a Parameter
238    back out if you try to create one of these.  This is similar to the trick
239    we implemented for Parameter to get it to work with subclasses; this
240    is primarily so that FlatParameter supports combination with FakeTensor.
241
242    Attributes:
243        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
244            without right-hand-side padding for divisibility by the world size.
245            For ``use_orig_params=True``, this includes alignment padding.
246        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
247            with right-hand-side padding for divisibility by the world size.
248            For ``use_orig_params=True``, this includes alignment padding. This
249            is only set for sharded strategies since they require padding for
250            the all-gather.
251        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
252            This is also set for ``NO_SHARD``, in which case it is the same as
253            the unsharded sizes. (We omit "padded" because there is no
254            analogous unpadded one.)
255
256        _num_params (int): Number of original parameters flattened into this
257            flat parameter. This is the length of the per-parameter data
258            structures.
259        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
260            entry; see :class:`ParamInfo` for details.
261        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
262        _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
263            prefixed from the ``_fully_sharded_module``. The names are
264            guaranteed to be unique in the subtree rooted at that module.
265        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
266            extension (i.e. some per-parameter state) used to customize
267            pre-flatten and post-unflatten behavior or ``None``. This is
268            experimental, and users should not depend on its existence in the
269            future.
270        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
271            including entries for the padding. This is used to construct views
272            into the flat parameter via ``torch.split()``. This may have length
273            longer than ``_num_params``.
274        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
275            padding. This has length equal to ``_num_params``.
276        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
277            shard parameter info; see :class:`_ShardParamInfo` for details.
278        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
279            info entries; see :class:`SharedParamInfo` for details.
280        _modules (Set[nn.Module]): Modules that contain some original parameter
281            that is flattened into the flat parameter.
282
283        _shard_numel_padded (int): Numel padded for this rank's sharded flat
284            parameter.
285        _local_shard (Tensor): Sharded flat parameter with padding if using a
286            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
287            unsharded flat parameter, and there is no notion of a sharded flat
288            parameter or padded unsharded flat parameter.
289        _full_param_padded (Tensor): Unsharded flat parameter with padding.
290            This is not defined for ``NO_SHARD``. When using mixed precision
291            for parameters, this has the low precision.
292        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
293            parameter with padding. This is used for unsharding outside of
294            computation when using mixed precision for parameters. This is
295            never defined for ``NO_SHARD``.
296        _post_backward_hook_handle (RemovableHandle):
297            Flat parameter's post-backward hook handle. (Compile only)
298        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
299            Flat parameter's :class:`AccumulateGrad` object and post-backward
300            hook handle. (Eager only)
301        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
302            This is only defined when parameter mixed precision is enabled. For
303            ``NO_SHARD``, this is used for computation.
304        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
305            This is only defined when offloading parameters is enabled.
306        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
307            iterations for gradient accumulation without :meth:`no_sync`.
308
309        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
310            then each original parameter variable; otherwise, ``None``. This
311            does not include any padding tensors.
312        _shared_params (Optional[List[nn.Parameter]]): The original shared
313            parameter variables if ``use_orig_params=True`` and ``None``
314            otherwise.
315        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
316            views created in the forward and tracked by autograd when
317            ``use_orig_params=True`` and is ``None`` otherwise. This is to
318            preserve those ``Tensor`` variables for the backward to ensure that
319            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change,
320            since if it changed, the post-backward hook would not run. This is
321            relevant for cases like reentrant activation checkpointing.
322        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
323            a mask over the original parameters' gradients indicating whether
324            each is logically ``None`` or not; otherwise, ``None``. This does not
325            include entries for padding. This mask is needed because only some
326            of the parameters may have ``None`` gradient, in which case the
327            flat gradient must be non-``None`` and must use zeros to
328            approximate those original ``None`` gradients. This mask informs
329            FSDP to set the original parameter gradients to ``None`` (instead
330            of zeros) as needed.
331    """
332
333    _unpadded_unsharded_size: torch.Size
334    _padded_unsharded_size: torch.Size
335    _sharded_size: torch.Size
336    _num_params: int
337    _param_infos: Tuple[ParamInfo, ...]
338    _shapes: Tuple[torch.Size, ...]
339    _fqns: Tuple[str, ...]
340    _param_extensions: Tuple[Optional[Any], ...]
341    _numels_with_padding: Tuple[int, ...]
342    _numels: Tuple[int, ...]
343    _shard_param_infos: Tuple[_ShardParamInfo, ...]
344    _shared_param_infos: Tuple[SharedParamInfo, ...]
345    _modules: Set[nn.Module]
346    _shard_numel_padded: int
347    _local_shard: Tensor
348    _full_param_padded: Tensor
349    _full_prec_full_param_padded: Tensor
350    # Eager only
351    _post_backward_hook_state: Tuple[Any, Any]
352    # Compile only
353    _post_backward_hook_handle: Any
354    _mp_shard: Tensor
355    _cpu_grad: Tensor
356    _saved_grad_shard: Tensor
357    _params: Optional[List[nn.Parameter]]
358    _shared_params: Optional[List[nn.Parameter]]
359    _tensors: Optional[List[Optional[Tensor]]]
360    _is_grad_none_mask: Optional[List[bool]]
361
362    _is_padding_mask: List[bool]
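    # Illustrative sketch (hypothetical sizes): flattening two original
    # parameters of numel 5 and 3 with a 3-numel alignment padding region
    # between them would yield
    #   _numels_with_padding = (5, 3, 3)            # includes the padding entry
    #   _is_padding_mask     = [False, True, False]
    #   _numels              = (5, 3)               # length == _num_params == 2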
363
364    def __new__(cls, data=None, requires_grad=True):
365        assert cls is FlatParameter, "subclasses FlatParameter not supported"
366        r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
367        r._is_flat_param = True  # type: ignore[attr-defined]
368        return r
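    # Construction sketch (illustrative): since `__new__` returns a plain
    # `nn.Parameter` tagged with `_is_flat_param`, there are no true
    # `FlatParameter` instances, e.g.
    #   p = FlatParameter(torch.zeros(3), requires_grad=False)
    #   type(p) is nn.Parameter        # True
    #   isinstance(p, FlatParameter)   # True, via `_FlatParameterMeta`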
369
370    # NB: This is not a regular method, because FlatParameters are not actually
371    # instances of this class (see __new__ above).  So this cannot be called
372    # as a regular method; instead, call it through the classmethod.
373    @classmethod
374    def _init_metadata(
375        cls,
376        self,
377        param_infos: List[ParamInfo],
378        numels: List[int],
379        shapes: List[torch.Size],
380        fqns: List[str],
381        shared_param_infos: List[SharedParamInfo],
382        param_extensions: List[Optional[Any]],
383        params: Optional[List[nn.Parameter]],
384        shared_params: Optional[List[nn.Parameter]],
385        is_padding_mask: List[bool],
386    ) -> None:
387        """
388        Initialize attributes holding metadata about the original parameters comprising the flat parameter.
389
390        We expose this method separately from the constructor to keep the
391        constructor only responsible for the flat parameter's tensor data. This
392        method should only be called once per model, while the constructor may
393        be called multiple times, e.g. when reloading from a checkpoint, in
394        which case only the tensor data needs to be passed to the constructor.
395        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
396        metadata is correctly assumed to be unchanged.
397
398        Args:
399            See the Attributes in the class docstring.
400        """
401        assert len(param_infos) == len(shapes)
402        assert len(param_infos) == len(fqns)
403        assert len(param_infos) == len(param_extensions)
404        self._num_params = len(param_infos)
405        self._param_infos = param_infos
406        self._shapes = shapes
407        self._fqns = fqns
408        self._param_extensions = param_extensions
409        self._is_padding_mask = is_padding_mask
410
411        numels_without_padding: List[int] = []
412        for numel, is_padding in zip(numels, is_padding_mask):
413            if not is_padding:
414                numels_without_padding.append(numel)
415        self._numels = tuple(numels_without_padding)
416        self._numels_with_padding = tuple(numels)
417        assert len(self._numels) == self._num_params
418
419        self._shared_param_infos = tuple(shared_param_infos)
420        self._modules = {pi.module for pi in self._param_infos}.union(
421            {spi.module for spi in self._shared_param_infos}
422        )
423        assert (params is None) == (shared_params is None)
424        if params is not None:
425            assert shared_params is not None and len(shared_params) == len(
426                shared_param_infos
427            )
428            self._params = []
429            for param, is_padding in zip(params, is_padding_mask):
430                if not is_padding:
431                    self._params.append(param)
432            self._shared_params = shared_params
433            # Mark the original parameters to avoid flattening them into
434            # another `FlatParameter` during recursive construction
435            for param in chain(self._params, self._shared_params):
436                _set_fsdp_flattened(param)
437            self._is_grad_none_mask = [False for _ in range(self._num_params)]
438            self._tensors = [None for _ in range(self._num_params)]
439        else:
440            self._params = None
441            self._shared_params = None
442            self._is_grad_none_mask = None
443            self._tensors = None
444        self._unpadded_unsharded_size = self.size()
445        _set_fsdp_flattened(self)
446        # Tracks whether the `FlatParameter`'s post-backward hook has been
447        # called to modify the behavior of the post-backward callback
448        self._post_backward_called = False
449
450
451class FlatParamHandle:
452    """
453    A handle that manages a flat parameter (:class:`FlatParameter`).
454
455    This includes sharding and view management.
456
457    Args:
458        params (Sequence[nn.Parameter]): The parameters to flatten into the
459            flat parameter.
460        fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
461        device (torch.device): The compute and communication device, which
462            should be a non-CPU device. We refer to it as the compute device.
463        sharding_strategy (ShardingStrategy): Sharding strategy to apply to
464            this handle's ``FlatParameter``.
465        offload_params (bool): Whether to offload the handle's
466            ``FlatParameter`` to CPU.
467        mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
468            setting passed to the FSDP constructor.
469        mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
470            precision setting passed to the FSDP constructor.
471        keep_low_precision_grads (bool): Whether to keep gradients in low
472            precision.
473        use_orig_params (bool): If ``True``, then FSDP preserves the original
474            parameter variables and returns them from ``named_parameters()``
475            (e.g. to support different optimizer hyperparameters within one
476            :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
477            parameters every iteration and returns the :class:`FlatParameter` s
478            from ``named_parameters()``.
479    """
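    # Minimal construction sketch (illustrative; assumes a default process group
    # is initialized and `my_module` is a hypothetical module already on `device`):
    #   handle = FlatParamHandle(
    #       params=list(my_module.parameters()),
    #       fully_sharded_module=my_module,
    #       device=torch.device("cuda", 0),
    #       sharding_strategy=HandleShardingStrategy.FULL_SHARD,
    #       offload_params=False,
    #       mp_param_dtype=None,
    #       mp_reduce_dtype=None,
    #       keep_low_precision_grads=False,
    #       process_group=dist.group.WORLD,
    #       use_orig_params=True,
    #   )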
480
481    ##################
482    # INITIALIZATION #
483    ##################
484    def __init__(
485        self,
486        params: Sequence[Union[nn.Parameter, Tensor]],
487        fully_sharded_module: nn.Module,
488        device: torch.device,
489        sharding_strategy: HandleShardingStrategy,
490        offload_params: bool,
491        mp_param_dtype: Optional[torch.dtype],
492        mp_reduce_dtype: Optional[torch.dtype],
493        keep_low_precision_grads: bool,
494        process_group: dist.ProcessGroup,
495        use_orig_params: bool,
496        *,
497        fsdp_extension: Optional[FSDPExtensions] = None,
498    ):
499        super().__init__()
500        params = list(params)
501        if len(params) == 0:
502            raise ValueError(
503                f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
504            )
505        self._init_setattr_fns()
506        self._skip_writeback_check = (
507            os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
508        )
509        self._use_full_prec_in_eval = (
510            os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
511        )
512        self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
513        self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"
514        if self._skip_writeback_check:
515            _warn_skip_writeback_check(
516                logger,
517                f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
518                "for parameter or gradient writeback. Changing parameter or "
519                "gradient storages may lead to silent correctness errors.",
520            )
521        if self._use_fake_all_gather:
522            _warn_use_fake_all_gather(
523                logger,
524                f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
525                "all-gather ops. Your training will be incorrect, but "
526                "this can reveal how much time is spent on all-gather ops.",
527            )
528        if self._use_fake_reduce:
529            _warn_use_fake_reduce(
530                logger,
531                f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
532                "reduce-scatter ops. Your training will be incorrect, but "
533                "this can reveal how much time is spent on reduce-scatter ops.",
534            )
535        # Only align addresses for `use_orig_params=True` (for now)
536        align_addresses = use_orig_params
537        self._init_get_unflat_views_fn(align_addresses)
538        self.device = device
539        self._device_handle = _FSDPDeviceHandle.from_device(self.device)
540        self.process_group = process_group
541        if self._use_fake_all_gather or self._use_fake_reduce:
542            self._fake_process_group = FakeProcessGroup(
543                rank=process_group.rank(), world_size=process_group.size()
544            )
545        self.rank = process_group.rank()
546        self.world_size = process_group.size()
547        self._sharding_strategy = sharding_strategy
548        self._offload_params = offload_params
549        self._use_orig_params = use_orig_params
550        self._keep_low_precision_grads = keep_low_precision_grads
551        self._training_state = HandleTrainingState.IDLE
552        self._debug_level = dist.get_debug_level()
553        self._fully_sharded_module = fully_sharded_module
554        # For strategies that do not free after forward, we skip using sharded
555        # views after forward since the unsharded data exists. We still switch
556        # `self.flat_param` to point to the sharded flat parameter since what
557        # it points to parameterizes behavior. We use the following attribute
558        # to track which tensor data the parameters are unsharded views into.
559        self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
560        # The index in the state's `all_handles`, which must be the
561        # same across ranks for the execution order validation to work
562        self._handle_index: Optional[int] = None
563        # Index in `handles_to_pre_forward_order`
564        self._pre_forward_order_index: Optional[int] = None
565        # Index in `handles_post_forward_order`
566        self._post_forward_index: Optional[int] = None
567        # Used for guarding against mistargeted forward prefetches
568        self._needs_pre_forward_unshard = False
569        # Used for guarding against mistargeted backward prefetches
570        self._needs_pre_backward_unshard = False
571        # Was the handle prefetched? Set on successful _prefetch_handle and unshard
572        self._prefetched = False
573        # Optimistically assume a valid input `params` and set dtype attributes
574        # before `_init_flat_param_and_metadata()`, which performs the actual validation
575        self._orig_param_dtype = params[0].dtype
576        self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
577        assert self._fwd_bwd_param_dtype is not None  # mypy
578        self._aligned_numel = (
579            _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
580            if align_addresses
581            else 0
582        )
583        self._fsdp_extension = fsdp_extension
584        self._init_flat_param_and_metadata(
585            params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
586        )
587        self._use_unsharded_views(as_params=False)
588
589    def _init_setattr_fns(self):
590        use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
591        self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
592        self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
593        if use_unsafe_setattr:
594            self._setattr_tensor = _unsafe_setattr_tensor
595            self._setattr_param = _unsafe_setattr_param
596        else:
597            self._setattr_tensor = _safe_setattr_tensor_or_param
598            self._setattr_param = _safe_setattr_tensor_or_param
599
600    def _init_get_unflat_views_fn(self, align_addresses: bool):
601        self._get_unflat_views = (
602            self._get_unflat_views_aligned
603            if align_addresses
604            else self._get_unflat_views_unaligned
605        )
606
607    def _init_flat_param_and_metadata(
608        self,
609        params: List[Union[Tensor, nn.Parameter]],
610        module: nn.Module,
611        aligned_numel: int,
612        use_orig_params: bool,
613    ) -> None:
614        """
615        Initialize the ``FlatParameter`` and its metadata.
616
617        NOTE: This should only be called once at construction time, after which
618        the ``FlatParameter`` metadata is assumed to be static.
619
620        NOTE: The elements of ``params`` should only be ``Tensor`` s when
621        composing with ``DTensor`` -based tensor parallelism, in which case the
622        elements may be ``DTensor`` local shards.
623        """
624        if len(params) == 0:
625            raise ValueError("Expects non-empty `params`")
626        if aligned_numel < 0:
627            raise ValueError(
628                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
629            )
630        (
631            dtype,
632            flat_param_requires_grad,
633            device,
634        ) = self._validate_tensors_to_flatten(params)
635        params_set = set(params)
636        # For alignment padding, only `params_to_flatten`, `is_padding_mask`,
637        # and `numels` get entries; the other per-parameter lists get none.
638        param_infos: List[ParamInfo] = []
639        numels: List[int] = []
640        shapes: List[torch.Size] = []
641        fqns: List[str] = []
642        shared_param_infos: List[SharedParamInfo] = []
643        shared_param_memo: Dict[
644            Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
645        ] = {}
646        params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
647        shared_params: List[Union[Tensor, nn.Parameter]] = []
648        param_extensions: List[Any] = []
649        is_padding_mask: List[bool] = []
650        total_numel = total_numel_without_padding = 0
651        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
652            for param_name, param in _named_parameters_with_duplicates(
653                submodule, recurse=False
654            ):
655                if param not in params_set:
656                    continue
657                if param in shared_param_memo:  # shared reference
658                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
659                        param
660                    ]
661                    shared_params.append(param)
662                    shared_param_infos.append(
663                        SharedParamInfo(
664                            param_name,
665                            submodule,
666                            submodule_name,
667                            prim_param_name,
668                            prim_module,
669                            prim_module_name,
670                        )
671                    )
672                else:
673                    if aligned_numel > 0:
674                        numel_to_pad = aligned_numel - (total_numel % aligned_numel)
675                        if numel_to_pad > 0 and numel_to_pad < aligned_numel:
676                            padding_tensor = _construct_padding_tensor(
677                                numel_to_pad, dtype, False, device
678                            )
679                            params_to_flatten.append(padding_tensor)
680                            is_padding_mask.append(True)
681                            numels.append(numel_to_pad)
682                            total_numel += numel_to_pad
683                    transform_t, extension = _ext_pre_flatten_transform(
684                        param,
685                        self._fsdp_extension,
686                    )
687                    param = cast(nn.Parameter, transform_t)
688                    param_extensions.append(extension)
689                    shared_param_memo[param] = (submodule, submodule_name, param_name)
690                    params_to_flatten.append(param)
691                    is_padding_mask.append(False)
692                    param_infos.append(ParamInfo(param_name, submodule, submodule_name))
693                    numels.append(param.numel())
694                    shapes.append(param.shape)
695                    fqn = (
696                        submodule_name + "." + param_name
697                        if submodule_name
698                        else param_name
699                    )
700                    fqns.append(fqn)
701                    total_numel += param.numel()
702                    total_numel_without_padding += param.numel()
703        if len(params_to_flatten) == 0:
704            raise ValueError(
705                f"`params` were not found in `module`'s tree\n"
706                f"params: {params}\nmodule: {module}"
707            )
708        if (
709            self.rank == 0
710            and aligned_numel > 0
711            and total_numel != total_numel_without_padding
712        ):
713            logger.debug(
714                "FSDP FlatParameter address alignment created "
715                "%s numel of padding (%s vs. %s)",
716                total_numel - total_numel_without_padding,
717                total_numel,
718                total_numel_without_padding,
719            )
720        if aligned_numel > 0:
721            # Pad to be divisible by world size to avoid a copy for the
722            # post-backward reduce-scatter
723            numel_to_pad = self.world_size - (total_numel % self.world_size)
724            if numel_to_pad > 0 and numel_to_pad < self.world_size:
725                if self.rank == 0:
726                    logger.info(
727                        "FSDP FlatParameter world size divisibility created "
728                        "%s numel of padding",
729                        numel_to_pad,
730                    )
731                padding_tensor = _construct_padding_tensor(
732                    numel_to_pad, dtype, False, device
733                )
734                params_to_flatten.append(padding_tensor)
735                is_padding_mask.append(True)
736                numels.append(numel_to_pad)
737                total_numel += numel_to_pad
738        # Pass `aligned_numel=0` since we already included padding tensors
739        self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
740            params_to_flatten,
741            aligned_numel=0,
742            requires_grad=flat_param_requires_grad,
743        )
744        FlatParameter._init_metadata(
745            self.flat_param,
746            param_infos,
747            numels,
748            shapes,
749            fqns,
750            shared_param_infos,
751            param_extensions,
752            _convert_to_params(params_to_flatten) if use_orig_params else None,
753            _convert_to_params(shared_params) if use_orig_params else None,
754            is_padding_mask,
755        )
756
757    def _validate_tensors_to_flatten(
758        self, tensors: List[Union[Tensor, nn.Parameter]]
759    ) -> Tuple:
760        """Validate the tensors to flatten and return any necessary metadata."""
761        dtype: Optional[torch.dtype] = None
762        # Returned as the logical OR over each tensor's `requires_grad`
763        flat_param_requires_grad: Optional[bool] = None
764        device: Optional[torch.device] = None
765        # For `use_orig_params=True`, permit non-uniform `requires_grad`
766        for tensor in tensors:
767            if isinstance(tensor, FlatParameter):
768                raise ValueError("Cannot flatten a `FlatParameter`")
769            if dtype is None and not tensor.is_floating_point():
770                raise ValueError("Cannot flatten integer dtype tensors")
771            if dtype is not None and tensor.dtype != dtype:
772                raise ValueError(
773                    f"Must flatten tensors with uniform dtype but got {dtype} "
774                    f"and {tensor.dtype}"
775                )
776            if (
777                not self._use_orig_params
778                and flat_param_requires_grad is not None
779                and tensor.requires_grad != flat_param_requires_grad
780            ):
781                raise ValueError(
782                    "Must flatten tensors with uniform `requires_grad` when "
783                    "`use_orig_params=False`"
784                )
785            if device is not None and tensor.device != device:
786                raise ValueError(
787                    "Must flatten tensors on the same device but got both "
788                    f"{device} and {tensor.device}"
789                )
790            dtype = tensor.dtype
791            flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
792            device = tensor.device
793        assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
794        return dtype, flat_param_requires_grad, device
795
796    def flatten_tensors(
797        self,
798        tensors: List[Tensor],
799        aligned_numel: int,
800    ) -> Tensor:
801        """
802        Flatten ``tensors`` into a single flat tensor.
803
804        The flattening optionally includes
805        padding if ``aligned_numel`` is greater than 0, where ``aligned_numel``
806        gives the numel required to have address alignment.
807
808        NOTE: The padding alignment algorithm must be kept in sync with
809        :meth:`_init_flat_param_and_metadata`. We separate the two methods because
810        the initialization happens once, whereas this method may be called
811        multiple times throughout training (e.g. for checkpointing).
812        """
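        # Worked example (hypothetical values): with `aligned_numel=4`,
        # `world_size=2`, and two tensors of numel 3 and 5, the resulting layout
        # is [t0 (3)][pad (1)][t1 (5)][pad (1)], i.e. a flat tensor of numel 10:
        # the 1-numel pad after t0 restores 4-numel alignment, and the final
        # 1-numel pad makes the total divisible by the world size.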
813        if len(tensors) == 0:
814            raise ValueError("Expects non-empty `tensors`")
815        if aligned_numel < 0:
816            raise ValueError(
817                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
818            )
819        dtype, _, device = self._validate_tensors_to_flatten(tensors)
820        flat_tensors: List[Tensor] = []
821        if aligned_numel > 0:
822            total_numel = 0
823            for tensor in tensors:
824                numel_to_pad = aligned_numel - (total_numel % aligned_numel)
825                if numel_to_pad > 0 and numel_to_pad < aligned_numel:
826                    padding_tensor = _construct_padding_tensor(
827                        numel_to_pad, dtype, False, device
828                    )
829                    flat_tensors.append(padding_tensor)
830                    total_numel += numel_to_pad
831                flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
832                total_numel += tensor.numel()
833            numel_to_pad = self.world_size - (total_numel % self.world_size)
834            if numel_to_pad > 0 and numel_to_pad < self.world_size:
835                padding_tensor = _construct_padding_tensor(
836                    numel_to_pad, dtype, False, device
837                )
838                flat_tensors.append(padding_tensor)
839                total_numel += numel_to_pad
840        else:
841            flat_tensors = [
842                torch.flatten(_detach_if_needed(tensor)) for tensor in tensors
843            ]
844        return torch.cat(flat_tensors, dim=0)
845
846    def flatten_tensors_into_flat_param(
847        self,
848        tensors: List[Tensor],
849        aligned_numel: int,
850        requires_grad: bool,
851    ) -> FlatParameter:
852        flat_param_data = self.flatten_tensors(tensors, aligned_numel)
853        return FlatParameter(flat_param_data, requires_grad=requires_grad)
854
855    def _init_param_reduce_dtypes(
856        self,
857        mp_param_dtype: Optional[torch.dtype],
858        mp_reduce_dtype: Optional[torch.dtype],
859    ) -> None:
860        """
861        Initialize param and reduce dtypes.
862
863        Precondition: ``self._orig_param_dtype`` is set (this may run before the flat
864        parameter is constructed), and the parameters are assumed to share one dtype.
865
866        Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
867        ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
868        is ``None``, then we assume the original parameter dtype. One special
869        case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
870        is ``None``, in which case we assume the gradient reduction dtype
871        matches the forward/backward parameter dtype.
872        """
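        # Illustrative combinations following the postcondition above (assuming
        # fp32 original parameters):
        #   mp_param_dtype=torch.bfloat16, mp_reduce_dtype=None
        #       -> _fwd_bwd_param_dtype=torch.bfloat16, _reduce_dtype=torch.bfloat16
        #   mp_param_dtype=None, mp_reduce_dtype=torch.float16
        #       -> _fwd_bwd_param_dtype=torch.float32,  _reduce_dtype=torch.float16
        #   mp_param_dtype=None, mp_reduce_dtype=None
        #       -> _fwd_bwd_param_dtype=torch.float32,  _reduce_dtype=torch.float32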
873        # Save whether these dtypes were specified so that we permit the
874        # parameter dtype to change up until the lazy initialization
875        self._low_prec_param_dtype_specified = mp_param_dtype is not None
876        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
877        if (
878            self._low_prec_param_dtype_specified
879            and not self._low_prec_reduce_dtype_specified
880        ):
881            # Special case: infer gradient reduction mixed precision
882            self._fwd_bwd_param_dtype = mp_param_dtype
883            self._reduce_dtype = self._fwd_bwd_param_dtype
884        else:
885            self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
886            self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
887        assert self._fwd_bwd_param_dtype is not None
888        assert self._reduce_dtype is not None
889
890    ###################################
891    # SHARD INITIALIZATION & METADATA #
892    ###################################
893    @torch.no_grad()
894    def shard(self):
895        """
896        Shard the handle's ``FlatParameter``.
897
898        This allocates new memory for
899        the sharded flat parameter and frees the unsharded flat parameter's
900        storage.
901
902        Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
903        metadata attributes are set for all sharding strategies.
904        """
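        # Typical call sequence (sketch): FSDP constructs the handle with the
        # unsharded flat parameter and then calls `shard()` once, e.g.
        #   handle = FlatParamHandle(...)  # flat_param is unsharded
        #   handle.shard()                 # flat_param now holds this rank's padded shard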
905        flat_param = self.flat_param
906        if not self.uses_sharded_strategy:
907            self._init_shard_metadata(0, 0, flat_param.numel() - 1)
908        else:
909            _p_assert(
910                flat_param.storage_offset() == 0,
911                "The `FlatParameter` is not the sole occupant of its storage",
912            )
913            sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
914                flat_param, self.rank, self.world_size
915            )
916            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
917                allocated = flat_param._typed_storage()._size() > 0
918                if allocated:
919                    flat_param._typed_storage()._resize_(0)
920            flat_param.set_(sharded_flat_param)  # type: ignore[call-overload]
921            start_idx = sharded_flat_param.numel() * self.rank
922            end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1  # inclusive
923            self._init_shard_metadata(numel_padded, start_idx, end_idx)
924        if self._use_orig_params:
925            self._use_sharded_views()
926
927    def _init_shard_metadata(
928        self,
929        numel_padded: int,
930        unsharded_start_idx: int,
931        unsharded_end_idx: int,
932    ) -> None:
933        """
934        Initialize shard-related metadata for this rank's shard of the flat parameter.
935
936        This includes ``_sharded_size``, ``_shard_param_infos``, and ``_shard_numel_padded``.
937
938        Args:
939            numel_padded (int): Numel padded for this rank's sharded flat
940                parameter.
941            unsharded_start_idx (int): Start index in the unsharded flat
942                parameter assigned to this rank.
943            unsharded_end_idx (int): End index (inclusive) in the unsharded
944                flat parameter assigned to this rank.
945
946        Precondition: ``self.flat_param`` 's data is the sharded flat
947        parameter.
948        """
949        flat_param = self.flat_param
950        flat_param._sharded_size = flat_param.size()  # type: ignore[attr-defined]
951        sharded_flat_param_numel = flat_param.numel()  # includes `numel_padded`
952        _p_assert(
953            unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
954            f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
955        )
956        _p_assert(
957            numel_padded <= sharded_flat_param_numel,
958            f"numel_padded: {numel_padded} "
959            f"sharded_flat_param_numel: {sharded_flat_param_numel}",
960        )
961        shard_param_infos = self._get_shard_metadata(
962            unsharded_start_idx, unsharded_end_idx
963        )
964        assert (
965            len(shard_param_infos) == flat_param._num_params
966        ), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
967        flat_param._shard_param_infos = shard_param_infos  # type: ignore[attr-defined]
968        flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]
969
970    def _get_shard_metadata(
971        self,
972        unsharded_start_idx: int,
973        unsharded_end_idx: int,
974    ) -> Tuple[_ShardParamInfo, ...]:
975        """
976        Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive).
977
978        ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the
979        unsharded flat parameter specifying the shard.
980        """
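        # Worked example (hypothetical sizes): suppose the padded unsharded flat
        # parameter has 12 numel, built from original parameters of numel 5, 3,
        # and 3 plus 1 numel of trailing world-size padding, with `world_size=3`.
        # Rank 1's shard covers unsharded indices [4, 7], so:
        #   param0, offsets (0, 4)  -> _ShardParamInfo(True, 0, 1, 4, 4)
        #   param1, offsets (5, 7)  -> _ShardParamInfo(True, 1, 3, 0, 2)
        #   param2, offsets (8, 10) -> _ShardParamInfo(False, None, None, None, None)
        # The trailing padding entry is skipped entirely.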
981        flat_param_offsets = self._get_flat_param_offsets()
982        assert len(flat_param_offsets) == len(
983            self.flat_param._numels_with_padding
984        ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
985        shard_param_infos: List[_ShardParamInfo] = []
986        sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
987        # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
988        # into the unsharded flat parameter (inclusive) of the given parameter
989        for i, (
990            (unsharded_param_start_idx, unsharded_param_end_idx),
991            is_padding,
992        ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
993            if is_padding:
994                continue
995            in_sharded_flat_param = (
996                unsharded_start_idx <= unsharded_param_end_idx
997                and unsharded_end_idx >= unsharded_param_start_idx
998            )
999            if not in_sharded_flat_param:
1000                shard_param_info = _ShardParamInfo(False, None, None, None, None)
1001            else:
1002                if unsharded_start_idx <= unsharded_param_start_idx:
1003                    # This branch can only happen once since the rank's
1004                    # unsharded start index can only intersect one parameter
1005                    intra_param_start_idx = 0
1006                    offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
1007                else:
1008                    intra_param_start_idx = (
1009                        unsharded_start_idx - unsharded_param_start_idx
1010                    )
1011                    offset_in_shard = 0
1012                assert (
1013                    offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
1014                ), (
1015                    f"Invalid `offset_in_shard` of {offset_in_shard} for "
1016                    f"sharded flat parameter with {sharded_flat_param_numel} numel"
1017                )
1018                intra_param_end_idx = (
1019                    min(unsharded_param_end_idx, unsharded_end_idx)
1020                    - unsharded_param_start_idx
1021                )
1022                numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
1023                shard_param_info = _ShardParamInfo(
1024                    True,
1025                    offset_in_shard,
1026                    numel_in_shard,
1027                    intra_param_start_idx,
1028                    intra_param_end_idx,
1029                )
1030            shard_param_infos.append(shard_param_info)
1031        return tuple(shard_param_infos)
1032
1033    @staticmethod
1034    def _get_unpadded_shard(
1035        tensor: Tensor,
1036        rank: int,
1037        world_size: int,
1038    ) -> Tuple[Tensor, int]:
1039        """
1040        Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``.
1041
1042        The returned value is a tuple of the shard of ``tensor`` without any
1043        padding and the numel to pad for that shard.
1044
1045        If ``tensor`` is already flattened or may be viewed in the flattened
1046        shape (which is true in the expected usage), then this method does not
1047        allocate any new tensor memory.
1048        """
1049        chunks = torch.flatten(tensor).chunk(world_size)
1050        if len(chunks) < (rank + 1):
1051            # This rank gets an empty chunk fully padded with zeros since there
1052            # are not enough chunks across ranks
1053            chunk = chunks[0].new_empty(0)
1054        else:
1055            chunk = chunks[rank]
1056        numel_to_pad = chunks[0].numel() - chunk.numel()
1057        assert (
1058            numel_to_pad >= 0
1059        ), "Chunk's size should be at most the first chunk's size"
1060        return chunk, numel_to_pad
1061
1062    @staticmethod
1063    def _get_shard(
1064        tensor: Tensor,
1065        rank: int,
1066        world_size: int,
1067    ) -> Tuple[Tensor, int]:
1068        """
1069        Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size``, along with the numel padded for that shard.
1070
1071        This method allocates new memory (via :meth:`clone`) since the
1072        unsharded ``tensor`` may be deallocated after this method returns.
1073        """
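        # Worked example (hypothetical sizes): for a tensor of numel 10 and
        # `world_size=4`, `chunk(4)` yields chunks of numel 3, 3, 3, and 1, so
        # rank 3's unpadded shard has numel 1 with `numel_to_pad=2`, and the
        # returned shard is padded back up to numel 3.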
1074        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
1075            tensor, rank, world_size
1076        )
1077        shard = chunk.clone()
1078        if numel_to_pad > 0:
1079            shard = F.pad(shard, [0, numel_to_pad])
1080        return shard, numel_to_pad
1081
1082    @staticmethod
1083    def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
1084        """
1085        Return the shape of ``tensor`` after sharding including padding.
1086
1087        This requires ``tensor`` to have 1D shape and ensures that the returned
1088        shape is 1D.
1089        """
1090        assert len(tensor.shape) == 1, f"{tensor.shape}"
1091        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
1092            tensor, rank, world_size
1093        )
1094        unpadded_sharded_size = unpadded_sharded_tensor.size()
1095        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
1096        return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
1097
1098    def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
1099        """
1100        Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding).
1101
1102        NOTE: The returned list includes elements for alignment padding.
1103        """
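        # Worked example (hypothetical sizes): for
        # `_numels_with_padding = (5, 3, 3)`, the cumulative sums are
        # (5, 8, 11), so the returned offsets are [(0, 4), (5, 7), (8, 10)].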
1104        cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
1105        starts = [0] + cumulative_sum[:-1]
1106        ends = [end - 1 for end in cumulative_sum]  # inclusive
1107        param_offsets = list(zip(starts, ends))
1108        return param_offsets
1109
1110    @no_type_check
1111    def shard_metadata(
1112        self,
1113    ) -> FlatParamShardMetadata:
1114        """
1115        Return the shard-related metadata specific to this rank's shard of the flat parameter.
1116
1117        NOTE: The returned tuple does not include elements for alignment
1118        padding but does account for the padding.
1119        """
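        # Continuing the worked example in `_get_shard_metadata` (the FQNs and
        # shapes below are hypothetical): for rank 1, whose shard covers
        # parameters 0 and 1 only, this could return
        #   FlatParamShardMetadata(
        #       param_names=("a.weight", "b.bias"),
        #       param_shapes=(torch.Size([5]), torch.Size([3])),
        #       param_numels=(5, 3),
        #       param_offsets=((4, 4), (0, 2)),
        #   )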
1120        fqns_list = []
1121        shapes_list = []
1122        numels_list = []
1123        shard_param_offsets = []
1124        for fqn, shape, numel, shard_param_info in zip(
1125            self.flat_param._fqns,
1126            self.flat_param._shapes,
1127            self.flat_param._numels,
1128            self.flat_param._shard_param_infos,
1129        ):
1130            if not shard_param_info.in_shard:
1131                continue
1132            fqns_list.append(fqn)
1133            shapes_list.append(shape)
1134            numels_list.append(numel)
1135            shard_param_offsets.append(
1136                (
1137                    shard_param_info.intra_param_start_idx,
1138                    shard_param_info.intra_param_end_idx,
1139                )
1140            )
1141        return FlatParamShardMetadata(
1142            tuple(fqns_list),
1143            tuple(shapes_list),
1144            tuple(numels_list),
1145            tuple(shard_param_offsets),
1146        )
1147
1148    @no_type_check
1149    @torch.no_grad()
1150    def init_flat_param_attributes(self) -> None:
1151        """
1152        Initialize some attributes on the handle's ``FlatParameter``.
1153        This should be called during lazy initialization since it requires the
1154        parameter to be on the compute device if not offloading to CPU, and we
1155        want to give users the chance to move the parameter appropriately after
1156        the FSDP constructor.
1157
1158        For each tensor attribute on the ``FlatParameter``, see the unshard and
1159        reshard methods in this class for the allocation and free pattern.
1160        """
1161        flat_param = self.flat_param
1162        if flat_param.dtype != self._orig_param_dtype:
1163            # Entering this branch means that the user changed the parameter
1164            # dtype after FSDP initialization, in which case we may need to
1165            # refresh some saved dtype attributes (dtypes specified as a part
1166            # of mixed precision take precedence).
1167            if not self._low_prec_param_dtype_specified:
1168                self._fwd_bwd_param_dtype = flat_param.dtype
1169            # For `reduce_dtype`, require `param_dtype` was not specified since
1170            # then we infer the `reduce_dtype` from the specified `param_dtype`
1171            if (
1172                not self._low_prec_reduce_dtype_specified
1173                and not self._low_prec_param_dtype_specified
1174            ):
1175                self._reduce_dtype = flat_param.dtype
1176            self._orig_param_dtype = flat_param.dtype
1177        cpu_device = torch.device("cpu")
1178        if self._offload_params:
1179            _p_assert(
1180                flat_param.device == cpu_device,
1181                f"Expects the `FlatParameter` to be on CPU when parameter CPU "
1182                f"offloading is enabled, not {flat_param.device}",
1183            )
1184        else:
1185            self._check_on_compute_device(self.flat_param)
1186        flat_param._local_shard = flat_param.data
1187        if self._offload_params:
1188            # Pin the memory for faster H2D transfer
1189            flat_param._local_shard = flat_param._local_shard.pin_memory(
1190                device=self.device
1191            )
1192            # Pre-allocate the sharded gradient on CPU to enable non-blocking
1193            # D2H transfer during the backward pass
1194            flat_param._cpu_grad = torch.zeros_like(
1195                flat_param._local_shard, device=cpu_device
1196            ).pin_memory(device=self.device)
1197        if self._uses_param_mixed_precision:
1198            # For parameter mixed precision, we maintain a low precision
1199            # sharded tensor on the compute device to be all-gathered (for
1200            # sharded strategies) or directly used (for `NO_SHARD`) for
1201            # computation.
1202            flat_param._mp_shard = torch.empty_like(
1203                flat_param._local_shard,
1204                device=self.device,
1205                dtype=self._fwd_bwd_param_dtype,
1206            )
1207            _free_storage(flat_param._mp_shard)
1208        if self.uses_sharded_strategy:
1209            # We maintain a padded unsharded tensor that serves as the
1210            # all-gather destination and owns the original parameter storages.
1211            unsharded_param_dtype = (
1212                self._fwd_bwd_param_dtype
1213                if self._uses_param_mixed_precision
1214                else flat_param.dtype
1215            )  # use low precision if parameter mixed precision is enabled
1216            padded_unsharded_numel = flat_param.numel() * self.world_size
1217            flat_param._full_param_padded = torch.empty(
1218                padded_unsharded_numel,
1219                device=self.device,
1220                dtype=unsharded_param_dtype,
1221            )
1222            flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
1223            _free_storage(flat_param._full_param_padded)
1224
1225            if self._uses_param_mixed_precision:
1226                # For parameter mixed precision, we maintain a full precision
1227                # padded unsharded tensor for when we force full precision.
1228                flat_param._full_prec_full_param_padded = torch.empty(
1229                    padded_unsharded_numel,
1230                    device=self.device,
1231                    dtype=flat_param.dtype,  # full precision
1232                )
1233                _free_storage(flat_param._full_prec_full_param_padded)
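
    # Illustrative sketch (not part of this module): the attributes above follow
    # a "reserve metadata, free storage" pattern. As used in this file,
    # `_free_storage` shrinks a tensor's storage to 0 bytes while the tensor
    # keeps its size/dtype metadata, and `_alloc_storage` re-allocates it later
    # (e.g. right before serving as an all-gather destination):
    #
    #     buf = torch.empty(1024, device="cuda", dtype=torch.float16)
    #     _free_storage(buf)                       # storage becomes 0 bytes
    #     assert buf.untyped_storage().size() == 0
    #     _alloc_storage(buf, torch.Size([1024]))  # re-allocate before use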
1234
1235    ###################
1236    # UNSHARD/RESHARD #
1237    ###################
1238    def pre_unshard(self) -> bool:
1239        """
1240        Return ``False`` if this is a no-op and ``True`` otherwise.
1241
1242        Postcondition: ``self.flat_param`` 's data is on the device for
1243        communication and is what should be all-gathered. This means that it
1244        matches the dtype of the expected unsharded parameter.
1245        """
1246        if (
1247            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
1248            and self._skipped_use_sharded_views
1249        ):
1250            # Since this path imposes special semantics for the unsharded flat
1251            # parameter (e.g. forcing full precision), use sharded views to
1252            # reuse the existing logic for that special handling
1253            self._use_sharded_views()
1254        ret = False
1255        if self._use_orig_params and not self._skip_writeback_check:
1256            ret = self._writeback_orig_params()
1257        if (
1258            self.uses_sharded_strategy
1259            and not self._offload_params
1260            and not self.needs_unshard()
1261        ):
1262            pass  # no-op
1263        elif self._uses_param_mixed_precision and not self._force_full_precision:
1264            self._use_low_precision_shard()
1265            ret = True
1266        elif self._offload_params and self.flat_param.device != self.device:
1267            # NOTE: This creates a new tensor distinct from any attributes.
1268            self.flat_param_to(self.device, non_blocking=True)
1269            ret = True
1270        self._check_on_compute_device(self.flat_param)
1271        return ret
1272
1273    def _use_low_precision_shard(self):
1274        """Allocate on the compute device and switch to using the low precision sharded flat parameter."""
1275        self._check_low_precision_shard()
1276        flat_param = self.flat_param
1277        _alloc_storage(
1278            flat_param._mp_shard, flat_param._local_shard.size()  # type: ignore[attr-defined]
1279        )
1280        # `copy_()` implicitly casts to the low precision
1281        flat_param._mp_shard.copy_(  # type: ignore[attr-defined]
1282            flat_param._local_shard.to(  # type: ignore[attr-defined]
1283                self.device, non_blocking=True
1284            )
1285        )
1286        # Invariant: `_mp_shard` is always on the compute device.
1287        flat_param.data = flat_param._mp_shard  # type: ignore[attr-defined]
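
    # Minimal sketch of the implicit downcast relied on above: `Tensor.copy_()`
    # casts the source to the destination's dtype, so copying the full precision
    # local shard into the low precision `_mp_shard` performs the cast:
    #
    #     src = torch.randn(8)                       # full precision shard
    #     dst = torch.empty(8, dtype=torch.float16)  # low precision buffer
    #     dst.copy_(src)                             # implicit fp32 -> fp16 cast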
1288
1289    def unshard(self):
1290        """
1291        Run the unshard logic.
1292
1293        This includes all-gathering the flat parameter and switching to using
1294        the unsharded flat parameter. If the handle does not need unsharding,
1295        then this only switches to using the unsharded flat parameter. For
1296        ``NO_SHARD``, this is a no-op.
1297
1298        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
1299        mixed precision, then the parameter is forced to full precision.
1300        """
1301        if not self.needs_unshard():
1302            # Even when not needing an unshard, we should switch to using
1303            # the unsharded flat parameter
1304            unsharded_flat_param = (
1305                self._get_padded_unsharded_flat_param()
1306                if self.uses_sharded_strategy
1307                else self.flat_param
1308            )
1309            self._use_unsharded_flat_param(unsharded_flat_param)
1310            return
1311        unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
1312        padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
1313        self._use_unsharded_flat_param(padded_unsharded_flat_param)
1314
1315    def needs_unshard(self) -> bool:
1316        """Return whether the handle's flat parameter needs to be unsharded."""
1317        if not self.uses_sharded_strategy:
1318            return False
1319        unsharded_flat_param = self._get_padded_unsharded_flat_param()
1320        already_unsharded = _same_storage_size(
1321            unsharded_flat_param, unsharded_flat_param.numel()
1322        )
1323        return not already_unsharded
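
    # Illustrative sketch (assumption about the helper: `_same_storage_size(t, n)`
    # checks whether `t` has storage allocated for `n` elements): since freeing
    # the unsharded flat parameter shrinks the destination's storage to 0 bytes,
    # a full-sized storage means the parameter is already unsharded. A rough
    # equivalent of the check:
    #
    #     def storage_holds_numel(t: torch.Tensor, numel: int) -> bool:
    #         return t.untyped_storage().size() == numel * t.element_size()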
1324
1325    def _alloc_padded_unsharded_flat_param(self):
1326        """
1327        Allocate the *padded* unsharded flat parameter.
1328
1329        The unpadded unsharded flat parameter is always a view into the
1330        padded one. This padded parameter is saved to a different attribute
1331        on the ``FlatParameter`` depending on whether we force full
1332        precision.
1333        """
1334        self._check_sharded_strategy()
1335        flat_param = self.flat_param
1336        unsharded_flat_param = self._get_padded_unsharded_flat_param()
1337        self._check_storage_freed(unsharded_flat_param)
1338        _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
1339        return unsharded_flat_param
1340
1341    def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
1342        """
1343        Return a reference to the padded unsharded flat parameter depending on the calling context.
1344
1345        This should only be called if using a sharded strategy.
1346        """
1347        self._check_sharded_strategy()
1348        flat_param = self.flat_param
1349        if self._force_full_precision and self._uses_param_mixed_precision:
1350            # When parameter mixed precision is enabled, we use a different
1351            # tensor as the all-gather destination to preserve the invariant
1352            # that `_full_param_padded` is in the low precision
1353            unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
1354            _p_assert(
1355                unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
1356                f"Expects full precision but got {self._fwd_bwd_param_dtype}",
1357            )
1358            # For no-reshard-after-forward strategies, `_full_param_padded` may
1359            # still be allocated from a previous forward. Since we are forcing
1360            # full precision here, the full-precision unsharded copy may be
1361            # modified, which would invalidate the existing low-precision
1362            # unsharded copy, so we free it here to force a new all-gather in
1363            # the next forward/backward that picks up those modifications.
1364            if flat_param._full_param_padded.untyped_storage().size() > 0:
1365                _free_storage(flat_param._full_param_padded)
1366        else:
1367            unsharded_flat_param = flat_param._full_param_padded  # type: ignore[attr-defined]
1368        return unsharded_flat_param
1369
1370    def _all_gather_flat_param(
1371        self,
1372        padded_unsharded_flat_param: Tensor,
1373    ) -> Tensor:
1374        """
1375        All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.
1376
1377        The caller should then switch to using the all-gathered tensor.
1378        """
1379        _p_assert(
1380            hasattr(self, "process_group") and hasattr(self, "world_size"),
1381            "Expects a process group and world size to have been set via `shard()`",
1382        )
1383        sharded_flat_param = self.flat_param.data
1384        expected_numel = sharded_flat_param.numel() * self.world_size
1385        _p_assert(
1386            padded_unsharded_flat_param.numel() == expected_numel,
1387            f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
1388        )
1389
1390        pg = (
1391            self._fake_process_group
1392            if self._use_fake_all_gather
1393            else self.process_group
1394        )
1395
1396        # HACK this should be handled by C10D
1397        if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
1398            tensor_list = list(
1399                torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
1400            )
1401            dist.all_gather(tensor_list, sharded_flat_param, group=pg)
1402        else:
1403            dist.all_gather_into_tensor(
1404                padded_unsharded_flat_param,
1405                sharded_flat_param,
1406                pg,
1407            )
1408
1409        if self._offload_params:
1410            # In case of offloading, `flat_param.data` (i.e. sharded param) is
1411            # created on the pre-unshard stream. We need to hand it over to the
1412            # unshard stream for all-gather
1413            _no_dispatch_record_stream(
1414                sharded_flat_param,
1415                self._device_handle.current_stream(),  # unshard_stream
1416            )
1417        return padded_unsharded_flat_param
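
    # Illustrative sketch of the CPU fallback above (assumption: a CPU-capable
    # backend such as gloo): `torch.chunk` returns views into the padded
    # destination, so `dist.all_gather` writing into those views fills the
    # destination in place, matching `all_gather_into_tensor` semantics:
    #
    #     world_size = dist.get_world_size(pg)
    #     out = torch.empty(shard.numel() * world_size, dtype=shard.dtype)
    #     dist.all_gather(list(torch.chunk(out, world_size)), shard, group=pg)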
1418
1419    def _use_unsharded_flat_param(
1420        self,
1421        padded_unsharded_flat_param: torch.Tensor,
1422    ) -> None:
1423        """
1424        Switch to use the *unpadded* unsharded flat parameter.
1425
1426        This is a view into the *padded* unsharded flat parameter.
1427        """
1428        unsharded_size = self.flat_param._unpadded_unsharded_size
1429        flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
1430        # The slicing above is hidden from autograd by the `.data` assignment below
1431        self.flat_param.data = flat_param_part
1432        in_forward = self._training_state == HandleTrainingState.FORWARD
1433        in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
1434        if self._use_orig_params:
1435            if self._skipped_use_sharded_views and in_pre_backward:
1436                # This call corresponds to the complementary pre-backward
1437                # `_use_unsharded_views()` to the skipped pre-forward
1438                # `_use_sharded_views()`, so we should skip this one too.
1439                return
1440            # We use `Tensor` views in the forward so that they are tracked by
1441            # autograd. We use them in the pre-backward as well to support
1442            # reentrant activation checkpointing, which needs the views to be
1443            # tracked by autograd in the backward pass's recomputed forward.
1444            self._use_unsharded_views(
1445                as_params=(not in_forward and not in_pre_backward)
1446            )
1447        elif in_forward:
1448            self._use_unsharded_views(as_params=False)
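
    # Minimal sketch of the `.data` trick used above (hypothetical sizes): the
    # slice is a view into the padded buffer, and assigning it through `.data`
    # swaps the parameter's storage without recording any autograd history:
    #
    #     padded = torch.empty(12)
    #     p = nn.Parameter(torch.empty(10))
    #     p.data = padded[:10]  # `p` now aliases the padded buffer's memory
    #     assert p.data_ptr() == padded.data_ptr()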
1449
1450    def post_unshard(self):
1451        """
1452        Run the post-unshard logic.
1453
1454        This includes freeing the low precision shard if needed.
1455        """
1456        if self._uses_param_mixed_precision and self.uses_sharded_strategy:
1457            self._free_low_precision_sharded_param()
1458        self._check_on_compute_device(self.flat_param)
1459
1460    def _free_low_precision_sharded_param(self):
1461        """Free the low precision sharded flat parameter."""
1462        self._check_low_precision_shard()
1463        # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
1464        # unshard stream for sharded strategies, and consumed in both the
1465        # unshard and default streams for `NO_SHARD`. For sharded strategies,
1466        # the current stream here is the unshard stream, and for `NO_SHARD`,
1467        # it is the default stream. For `NO_SHARD`, only recording for the
1468        # default stream suffices since the default stream waits for the
1469        # unshard stream.
1470        _no_dispatch_record_stream(
1471            self.flat_param._mp_shard, self._device_handle.current_stream()  # type: ignore[attr-defined]
1472        )
1473        _free_storage(self.flat_param._mp_shard)  # type: ignore[attr-defined]
1474
1475    @torch.no_grad()
1476    def unshard_grad(self):
1477        """
1478        Unshard the handle's ``FlatParameter``'s gradient.
1479
1480        If all ranks have a ``None`` gradient, then all original parameters
1481        will have ``None`` gradients as well. This method performs an
1482        all-reduce and an all-gather. The additional all-reduce is tolerable
1483        since this method is not meant to be used on the computation critical
1484        path.
1485
1486        Postcondition: ``_saved_grad_shard`` is defined and contains the value
1487        to set ``flat_param.grad`` after gradients are resharded.
1488        """
1489        if not self.uses_sharded_strategy:
1490            self._use_unsharded_grad_views()
1491            return
1492        flat_param = self.flat_param
1493        self._check_unsharded(flat_param)
1494
1495        # Check if all ranks have a `None` gradient
1496        num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
1497        num_grad_none[0] = flat_param.grad is None
1498        dist.all_reduce(num_grad_none, group=self.process_group)
1499        if num_grad_none[0] == self.world_size:
1500            flat_param._saved_grad_shard = None  # type: ignore[assignment]
1501            self._use_unsharded_grad_views()
1502            return
1503
1504        if flat_param.grad is None:
1505            # In the case that only some ranks have a `None` gradient, we use
1506            # zeros as a best-effort approximation of those ranks' gradients
1507            if self._debug_level == dist.DebugLevel.INFO:
1508                warnings.warn(
1509                    f"[Rank {self.rank}] Only some but not all ranks have a "
1510                    "`None` `FlatParameter` gradient, so FSDP is using zeros to "
1511                    "approximate those ranks' sharded gradients being `None`"
1512                )
1513            flat_param._saved_grad_shard = None  # type: ignore[assignment]
1514            sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device)  # type: ignore[attr-defined]
1515        else:
1516            self._check_sharded(flat_param.grad)
1517            flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
1518            sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
1519        padded_unsharded_grad = torch.empty(
1520            flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
1521            device=self.device,
1522            dtype=sharded_grad.dtype,
1523        )
1524        dist.all_gather_into_tensor(
1525            padded_unsharded_grad, sharded_grad, self.process_group
1526        )
1527        unsharded_size = self.flat_param._unpadded_unsharded_size
1528        flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
1529            unsharded_size
1530        )
1531        self._use_unsharded_grad_views()
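
    # Minimal sketch of the "do all ranks have a `None` gradient?" agreement
    # above (assumption: a default process group is initialized; `device` and
    # `grad` are stand-ins for the handle's compute device and sharded grad):
    #
    #     flag = torch.zeros(1, dtype=torch.int32, device=device)
    #     flag[0] = int(grad is None)
    #     dist.all_reduce(flag)  # sums the per-rank flags
    #     all_none = flag.item() == dist.get_world_size()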
1532
1533    def reshard_grad(self):
1534        if self._use_orig_params:
1535            self._use_sharded_grad_views()
1536        if not self.uses_sharded_strategy:
1537            return
1538        self.flat_param.grad = self.flat_param._saved_grad_shard  # type: ignore[attr-defined]
1539        delattr(self.flat_param, "_saved_grad_shard")
1540
1541    def prepare_gradient_for_backward(self):
1542        """
1543        Prepare the gradient for the backward computation.
1544
1545        This is done by saving and clearing any existing sharded gradient
1546        in ``.grad`` to enable computing a new unsharded gradient.
1547        """
1548        _p_assert(
1549            self._training_state
1550            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
1551            "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
1552        )
1553        flat_param = self.flat_param
1554        if flat_param.grad is not None and (
1555            flat_param.grad.size() != flat_param._unpadded_unsharded_size
1556            or flat_param.grad.device != flat_param.device  # grad on CPU
1557        ):
1558            self._check_on_compute_device(self.flat_param)
1559            grad_offloaded = flat_param.grad.device != self.device
1560            _p_assert(
1561                not grad_offloaded or self._offload_params,
1562                f"Expects the sharded gradient to be on {self.device} "
1563                f"but got {flat_param.grad.device}",
1564            )
1565            prev_iter_synced_gradients = (
1566                flat_param.grad.size()
1567                == flat_param._local_shard.size()  # type: ignore[attr-defined]
1568            )
1569            if prev_iter_synced_gradients:
1570                # TODO (awgu): Gradient accumulation outside `no_sync()`
1571                # does not work with CPU offloading. The issue should be
1572                # that, in the post-backward hook, we cannot do an addition
1573                # between a CPU tensor (the existing sharded gradient) and
1574                # a GPU tensor (the new sharded gradient).
1575                if not grad_offloaded:
1576                    flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
1577                    sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
1578                else:
1579                    _p_assert(
1580                        hasattr(flat_param, "_cpu_grad"),
1581                        "`_cpu_grad` should be defined if the gradient is on CPU",
1582                    )
1583                    sharded_grad = flat_param._cpu_grad  # type: ignore[attr-defined]
1584                # If user specified to keep the gradient in low precision, then
1585                # the gradient may still be of the low precision dtype if the
1586                # user did not set the gradient to `None` after the previous
1587                # backward, in which case FSDP should cast back to the full
1588                # precision dtype so that FSDP can accumulate in that dtype in
1589                # the post-backward hook and assign to `.grad` in that dtype in
1590                # the post-backward callback.
1591                local_shard_dtype = flat_param._local_shard.dtype  # type: ignore[attr-defined]
1592                if (
1593                    self._keep_low_precision_grads
1594                    and sharded_grad.dtype != local_shard_dtype
1595                ):
1596                    sharded_grad.data = sharded_grad.to(local_shard_dtype)
1597            else:
1598                padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
1599                _p_assert(
1600                    flat_param.grad.size() == padded_unsharded_size,
1601                    "Expects `.grad` to be the unsharded gradient in "
1602                    f"`no_sync()` with size {padded_unsharded_size} "
1603                    f"but got size {flat_param.grad.size()}",
1604                )
1605            flat_param.grad = None
1606
1607    def prepare_gradient_for_optim(self):
1608        """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute."""
1609
1610        def cast_grad_to_param_dtype_if_needed(flat_param):
1611            # TODO (rohan-varma): test for full precision with keep_low_precision_grads
1612            if not self._force_full_precision and self._keep_low_precision_grads:
1613                _p_assert(flat_param.grad is not None, "Unexpected None grad!")
1614                if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
1615                    flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
1616                    if self._use_orig_params:
1617                        self._use_sharded_grad_views()
1618
1619        flat_param = self.flat_param
1620        # TODO (awgu): We should replace these conditional checks to encode
1621        # the logical intention more directly.
1622        if hasattr(flat_param, "_cpu_grad"):
1623            # NOTE: This branch includes `NO_SHARD`.
1624            self._check_sharded(flat_param)
1625            self._check_on_cpu(flat_param)
1626            flat_param.grad = flat_param._cpu_grad  # type: ignore[attr-defined]
1627            cast_grad_to_param_dtype_if_needed(flat_param)
1628        elif hasattr(flat_param, "_saved_grad_shard"):
1629            self._check_sharded(flat_param)
1630            self._check_on_compute_device(flat_param)
1631            if flat_param._saved_grad_shard is not None:
1632                self._check_on_compute_device(flat_param._saved_grad_shard)  # type: ignore[attr-defined]
1633            # If no sharded gradient was computed this iteration, then there is
1634            # no need to forward `_saved_grad_shard` to `grad`
1635            if flat_param._post_backward_called:  # type: ignore[attr-defined]
1636                flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
1637                if flat_param.grad is not None:
1638                    cast_grad_to_param_dtype_if_needed(flat_param)
1639        else:
1640            _p_assert(
1641                not self.uses_sharded_strategy
1642                or not flat_param._post_backward_called,  # type: ignore[attr-defined]
1643                "All sharded parameters that received a gradient in the "
1644                "post-backward should use `_saved_grad_shard`",
1645            )
1646        # Delete `_saved_grad_shard` since its existence indicates a previous
1647        # gradient to accumulate with in the post-backward hook
1648        if hasattr(flat_param, "_saved_grad_shard"):
1649            delattr(flat_param, "_saved_grad_shard")
1650
1651    @contextlib.contextmanager
1652    def to_cpu(self):
1653        """
1654        Move the unpadded unsharded flat parameter to CPU while in the context and move it back to the previous device upon exit.
1655
1656        For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter
1657        since (1) there is no reason to include the padding in the copy and (2)
1658        there is no use case for the sharded flat parameter.
1659
1660        Precondition: ``self.flat_param`` 's data is the unpadded unsharded
1661        flat parameter on the compute device, and the handle uses a sharded
1662        strategy.
1663        Postcondition: Same as the precondition.
1664        """
1665        self._check_sharded_strategy()
1666        _p_assert(
1667            self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
1668            f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
1669        )
1670        self._check_on_compute_device(self.flat_param)
1671        # Check that the unpadded unsharded flat parameter is a view into the
1672        # padded unsharded flat parameter as expected
1673        # NOTE: This check is not strictly needed for correctness but is a
1674        # useful sanity check since the tensor should only be used internally.
1675        _p_assert(
1676            _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()),
1677            "Expects the unpadded parameter to be a view into the padded parameter",
1678        )
1679        self.flat_param_to(torch.device("cpu"))
1680        self._free_unsharded_flat_param()
1681        try:
1682            yield
1683        finally:
1684            _p_assert(
1685                self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
1686                f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
1687            )
1688            padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
1689            # Copy from CPU to the compute device
1690            padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
1691                self.flat_param
1692            )
1693            self._use_unsharded_flat_param(padded_unsharded_flat_param)
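
    # Illustrative usage sketch (assumptions: `handle` is a `FlatParamHandle`
    # whose flat parameter is currently the unpadded unsharded parameter on the
    # compute device, and `module` is the corresponding fully sharded module):
    #
    #     with handle.to_cpu():
    #         # The flat parameter lives on CPU here, and the padded unsharded
    #         # buffer on the compute device has been freed.
    #         cpu_state = {k: v.clone() for k, v in module.state_dict().items()}
    #     # Upon exit, the flat parameter is back on the compute device.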
1694
1695    def reshard(self, free_unsharded_flat_param: bool):
1696        """
1697        Run the reshard logic.
1698
1699        This includes freeing the unsharded flat parameter if
1700        ``free_unsharded_flat_param`` and switching to using the sharded flat
1701        parameter. Note that this also implicitly offloads the sharded flat
1702        parameter (if CPU offloading is enabled) by pointing it to the
1703        ``_local_shard`` attribute, which resides on CPU.
1704        """
1705        # Switch to the sharded `FlatParameter` before freeing to prevent
1706        # "use-after-free"-type bugs with external profiling tools, where for
1707        # `use_orig_params=True`, the `param` does not point to valid memory
1708        # when setting `param.data = ...` in `_use_sharded_views()`.
1709        self._use_sharded_flat_param()
1710        if free_unsharded_flat_param:
1711            self._free_unsharded_flat_param()
1712
1713    def post_reshard(self):
1714        """
1715        Run the post-reshard logic.
1716
1717        This includes freeing any memory that can now be freed given that
1718        the ``FlatParameter`` points to the full precision sharded flat
1719        parameter.
1720
1721        Precondition: ``self.flat_param`` 's data points to the full precision
1722        sharded flat parameter.
1723        """
1724        # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it
1725        # is also the low precision *unsharded* flat parameter. Hence, we delay
1726        # the free until the reshard.
1727        if (
1728            self._uses_param_mixed_precision
1729            and not self.uses_sharded_strategy
1730            and not self._force_full_precision  # did not use the low precision shard
1731        ):
1732            self._free_low_precision_sharded_param()
1733
1734    def _free_unsharded_flat_param(self):
1735        """
1736        Free the padded unsharded flat parameter. We allow this function to
1737        be called even when the storage is not allocated.
1738
1739        The tensor to free depends on the calling context since the unshard
1740        may have forced full precision, in which case a different tensor is
1741        used.
1742        """
1743        self._check_sharded_strategy()
1744        unsharded_flat_param = self._get_padded_unsharded_flat_param()
1745        self._check_on_compute_device(unsharded_flat_param)
1746        # Do not free the memory until all ops in the current stream finish
1747        _no_dispatch_record_stream(
1748            unsharded_flat_param, self._device_handle.current_stream()
1749        )
1750        _free_storage(unsharded_flat_param)
1751
1752    def _use_sharded_flat_param(self) -> None:
1753        """Switch to using the sharded flat parameter."""
1754        flat_param = self.flat_param
1755        if self._use_orig_params:
1756            in_forward = self._training_state == HandleTrainingState.FORWARD
1757            skip_use_sharded_views = (
1758                torch.is_grad_enabled()
1759                and in_forward
1760                and self._sharding_strategy
1761                in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
1762            )
1763            # Only incur the extra `.data` call if needed
1764            if skip_use_sharded_views:
1765                unsharded_flat_param = flat_param.data
1766        if self._offload_params:
1767            device = flat_param._local_shard.device  # type: ignore[attr-defined]
1768            _p_assert(
1769                device == torch.device("cpu"),
1770                f"Expects the local shard to be on CPU but got {device}",
1771            )
1772        flat_param.data = flat_param._local_shard  # type: ignore[attr-defined]
1773        if self._use_orig_params:
1774            if skip_use_sharded_views:  # type: ignore[possibly-undefined]
1775                self._unsharded_flat_param_for_skipped_views = unsharded_flat_param  # type: ignore[possibly-undefined]
1776            else:
1777                self._use_sharded_views()
1778            # For the post-forward reshard, we may try to use sharded gradient
1779            # views (or unsharded gradient views if a gradient was accumulated
1780            # in `no_sync()`), but for the post-backward reshard, we delay the
1781            # call to after the reduce-scatter.
1782            if (
1783                in_forward  # type: ignore[possibly-undefined]
1784                # Skip using gradient views if we skipped using sharded views
1785                # since exposing unsharded parameters with sharded gradients
1786                # may be confusing to the user
1787                and not self._skipped_use_sharded_views
1788            ):
1789                # TODO: Change `_unpadded_unsharded_size` if we change the
1790                # gradient to be computed directly with padding.
1791                accumulated_grad_in_no_sync = (
1792                    flat_param.grad is not None
1793                    and self.uses_sharded_strategy
1794                    and flat_param.grad.shape == flat_param._unpadded_unsharded_size
1795                )
1796                if accumulated_grad_in_no_sync:
1797                    self._use_unsharded_grad_views()
1798                else:
1799                    self._use_sharded_grad_views()
1800
1801    #########
1802    # VIEWS #
1803    #########
1804    @no_type_check
1805    def _get_unflat_views_unaligned(
1806        self,
1807        tensor: Optional[torch.Tensor] = None,
1808    ) -> Iterator[Tensor]:
1809        """
1810        Return unflattened ``Tensor`` views into ``tensor``.
1811
1812        If ``tensor`` is ``None``, ``flat_param`` is used. The unflattening is
1813        based on ``flat_param`` 's metadata.
1814
1815        Examples for ``tensor`` include ``flat_param.grad`` or unsharded
1816        tensor optimizer state.
1817        """
1818        flat_param = self.flat_param
1819        if tensor is None:
1820            tensor = flat_param
1821        views = (
1822            _ext_post_unflatten_transform(
1823                subtensor.view(shape),
1824                param_extension,
1825                self._fsdp_extension,
1826            )
1827            for (subtensor, shape, param_extension) in zip(
1828                torch.split(tensor, flat_param._numels, dim=0),
1829                flat_param._shapes,
1830                flat_param._param_extensions,
1831            )
1832        )
1833        return views
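
    # Minimal sketch of the split-and-view unflattening above (hypothetical
    # numels/shapes; no FSDP extension transform involved):
    #
    #     flat = torch.arange(10.0)  # two original parameters: 2x3 and 4
    #     numels, shapes = [6, 4], [torch.Size([2, 3]), torch.Size([4])]
    #     views = [
    #         chunk.view(shape)
    #         for chunk, shape in zip(torch.split(flat, numels), shapes)
    #     ]
    #     # Both views share `flat` 's storage; no data is copied.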
1834
1835    @no_type_check
1836    def _get_unflat_views_aligned(
1837        self,
1838        tensor: Optional[Tensor] = None,
1839    ) -> List[Tensor]:
1840        """
1841        Return unflattened ``Tensor`` views into ``tensor`` with handling for padding.
1842
1843        This method has the same contract as :meth:`_get_unflat_views_unaligned`
1844        except it checks for ``None`` placeholders representing padding for
1845        alignment, which may incur slightly more CPU overhead.
1846        """
1847        flat_param = self.flat_param
1848        if tensor is None:
1849            tensor = flat_param
1850        splits: List[Tensor] = torch.split(
1851            tensor, flat_param._numels_with_padding, dim=0
1852        )
1853        idx = 0
1854        views: List[Tensor] = []
1855        for split, is_padding in zip(splits, flat_param._is_padding_mask):
1856            if is_padding:
1857                continue
1858            views.append(
1859                _ext_post_unflatten_transform(
1860                    split.view(flat_param._shapes[idx]),
1861                    flat_param._param_extensions[idx],
1862                    self._fsdp_extension,
1863                )
1864            )
1865            idx += 1
1866        return views
1867
1868    @no_type_check
1869    @torch.enable_grad()
1870    def _use_unsharded_views(self, as_params: bool) -> None:
1871        """
1872        Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it.
1873
1874        Args:
1875            as_params (bool): If ``True``, then registers the original
1876                parameters as ``nn.Parameter`` s; if ``False``, then registers
1877                the original parameters only as ``Tensor`` s. ``False`` should
1878                be used during forward/backward computation and when hiding the
1879                original parameters from :meth:`nn.Module.named_parameters`.
1880
1881        Note:
1882            When prefetching for the next forward, the current forward may be
1883            annotated with ``@torch.no_grad()``. The ``@torch.enable_grad()``
1884            decorator ensures a non-empty ``view.grad_fn``; otherwise,
1885            ``_post_backward_hook`` will not get called.
1886        """
1887        flat_param = self.flat_param
1888        self._check_unsharded(flat_param)
1889        views = self._get_unflat_views()
1890        from torch.distributed.tensor import DTensor
1891
1892        for i, (view, (param_name, module, _)) in enumerate(
1893            zip(views, flat_param._param_infos)
1894        ):
1895            if self._use_orig_params and as_params:
1896                if type(view) is DTensor:
1897                    # A `DTensor` `view` is not compatible with assigning
1898                    # `param.data = view`, so we cannot preserve the parameter
1899                    # variable.
1900                    self._setattr_param(
1901                        module,
1902                        param_name,
1903                        nn.Parameter(view, requires_grad=flat_param.requires_grad),
1904                    )
1905                    continue
1906                param = self.flat_param._params[i]
1907                self._setattr_param(module, param_name, param)
1908                param.data = view
1909            elif as_params:
1910                self._setattr_param(
1911                    module,
1912                    param_name,
1913                    nn.Parameter(view, requires_grad=flat_param.requires_grad),
1914                )
1915            else:  # `as_params=False`
1916                param_var: Tensor = view
1917                if self._use_orig_params:
1918                    if self._training_state == HandleTrainingState.FORWARD:
1919                        # Save the `Tensor` for the pre-backward
1920                        self.flat_param._tensors[i] = view  # save for pre-backward
1921                    elif self._training_state == HandleTrainingState.BACKWARD_PRE:
1922                        # Use the saved `Tensor` variable from the forward to
1923                        # preserve the autograd graph so that the post-backward
1924                        # hook fires (e.g. for reentrant AC)
1925                        tensor = self.flat_param._tensors[i]
1926                        tensor.data = view
1927                        param_var = tensor
1928                self._setattr_tensor(module, param_name, param_var)
1929                if (
1930                    self._use_orig_params
1931                    and self._training_state == HandleTrainingState.FORWARD
1932                ):
1933                    module._parameters[param_name] = param_var
1934        for i, (
1935            param_name,
1936            module,
1937            _,
1938            prim_param_name,
1939            prim_module,
1940            _,
1941        ) in enumerate(self.flat_param._shared_param_infos):
1942            prim_param: Union[Tensor, nn.Parameter] = getattr(
1943                prim_module, prim_param_name
1944            )
1945            _p_assert(
1946                not as_params or isinstance(prim_param, nn.Parameter),
1947                f"as_params={as_params} type(prim_param)={type(prim_param)}",
1948            )
1949            if self._use_orig_params and as_params:
1950                shared_param = self.flat_param._shared_params[i]
1951                self._setattr_param(module, param_name, shared_param)
1952                shared_param.data = prim_param
1953            elif as_params:
1954                self._setattr_param(module, param_name, prim_param)
1955            else:
1956                self._setattr_tensor(module, param_name, prim_param)
1957                if (
1958                    self._use_orig_params
1959                    and self._training_state == HandleTrainingState.FORWARD
1960                ):
1961                    module._parameters[param_name] = prim_param
1962
1963    @no_type_check
1964    def _use_unsharded_grad_views(self) -> None:
1965        """
1966        Unflatten the unsharded flat parameter's gradient.
1967
1968        The original parameter variables' gradients are set to be views into
1969        the unsharded flat parameter's gradient.
1970        """
1971        # Expects the gradient to be in `flat_param.grad`
1972        if self.flat_param.grad is None:
1973            for param in chain(self.flat_param._params, self.flat_param._shared_params):
1974                param.grad = None
1975            return
1976        self._check_unsharded(self.flat_param.grad)
1977        views = self._get_unflat_views(self.flat_param.grad)
1978        for i, (view, (param_name, module, _)) in enumerate(
1979            zip(views, self.flat_param._param_infos)
1980        ):
1981            _p_assert(
1982                hasattr(module, param_name),
1983                f"{self.flat_param._fqns[i]} is missing",
1984            )
1985            param = getattr(module, param_name)
1986            if (
1987                param.shape != view.shape
1988                or param.dtype != view.dtype
1989                or param.device != view.device
1990            ):
1991                # NOTE: This is a hack using `.data` to side step the check
1992                # that parameter/gradient sizes/dtypes/devices match. From
1993                # calling `reshard()`, `param` has the sharded size, has the
1994                # full precision dtype, and if CPU offloading is enabled, is on
1995                # CPU. Thus, one or more of the following cases can hold when
1996                # in `no_sync()`, where `view` is the original parameter's
1997                # gradient:
1998                # 1. `view` can have the unsharded size.
1999                # 2. `view` can have the parameter low precision dtype.
2000                # 3. `view` can be on GPU.
2001                if param.grad is None:
2002                    param.grad = torch.empty_like(param)
2003                param.grad.data = view
2004            else:
2005                param.grad = view
2006        for i, (
2007            param_name,
2008            module,
2009            module_name,
2010            prim_param_name,
2011            prim_module,
2012            _,
2013        ) in enumerate(self.flat_param._shared_param_infos):
2014            _p_assert(
2015                hasattr(module, param_name),
2016                f"{module_name + '.' + param_name if module_name else param_name} is missing",
2017            )  # did not save FQN info in `_shared_param_infos`
2018            param = getattr(module, param_name)
2019            prim_param = getattr(prim_module, prim_param_name)
2020            if (
2021                param.shape != prim_param.grad.shape
2022                or param.dtype != prim_param.grad.dtype
2023                or param.device != prim_param.grad.device
2024            ):
2025                # NOTE: This is the same hack to use `.data` to side step the
2026                # size check.
2027                if param.grad is None:
2028                    param.grad = torch.empty_like(param)
2029                param.grad.data = prim_param.grad
2030            else:
2031                param.grad = prim_param.grad
2032
2033    @contextlib.contextmanager
2034    def unflatten_as_params(self) -> Generator:
2035        """
2036        Unflatten the original parameters.
2037
2038        The function assumes that the flat parameter is unsharded. While in
2039        the context, it unflattens the original parameters as ``nn.Parameter``
2040        views into the flat parameter, and after the context, it restores the
2041        original parameters as ``Tensor`` views into the flat parameter.
2042        """
2043        self._use_unsharded_views(as_params=True)
2044        try:
2045            yield
2046        finally:
2047            self._use_unsharded_views(as_params=False)
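
    # Illustrative usage sketch (assumption: `handle` is a `FlatParamHandle`
    # whose flat parameter is already unsharded, e.g. under
    # `summon_full_params`):
    #
    #     with handle.unflatten_as_params():
    #         # The original parameters appear as `nn.Parameter` views here,
    #         # e.g. for code that expects registered parameters.
    #         ...
    #     # Afterwards, they are restored as plain `Tensor` views.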
2048
2049    @no_type_check
2050    @torch.no_grad()
2051    def _use_sharded_views(self) -> None:
2052        """
2053        Set the original parameter variables' data to be flattened views into the sharded flat parameter.
2054
2055        The views are kept as flattened to simplify the case where a parameter
2056        is sharded across ranks. Parameters whose data is not present in the
2057        sharded flat parameter have their data set to a size-0 empty tensor. We
2058        do not delete them in order to preserve expected behaviors like model
2059        printability. Parameters whose data is present must preserve their
2060        variables to be passable to an optimizer.
2061        """
2062        self._unsharded_flat_param_for_skipped_views = None
2063        if not self.uses_sharded_strategy:
2064            # For `NO_SHARD`, use the *unflattened* unsharded views since we
2065            # have the unsharded parameter
2066            self._use_unsharded_views(as_params=True)
2067            return
2068        flat_param = self.flat_param
2069        self._check_sharded(flat_param)
2070        # Construct once and reuse for all parameters not in the local shard
2071        size_0_empty_tensor = torch.empty(
2072            0,
2073            dtype=self.flat_param.dtype,  # in case `flat_param` changed dtype
2074            device=self.flat_param.device,
2075            requires_grad=False,
2076        )
2077        for param, shard_param_info, (param_name, module, _) in zip(
2078            flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
2079        ):
2080            self._setattr_param(module, param_name, param)
2081            if not shard_param_info.in_shard:
2082                # Allow the original data to be freed via garbage collection
2083                param.data = size_0_empty_tensor
2084            else:
2085                offset = shard_param_info.offset_in_shard
2086                numel_in_shard = shard_param_info.numel_in_shard
2087                param.data = flat_param[offset : offset + numel_in_shard]
2088        assert self.flat_param._shared_params is not None
2089        for i, (
2090            param,
2091            (param_name, module, _, prim_param_name, prim_module, _),
2092        ) in enumerate(
2093            zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
2094        ):
2095            self._setattr_param(module, param_name, param)
2096            prim_param = getattr(prim_module, prim_param_name)
2097            param.data = prim_param  # could be both empty and non-empty
2098        if self._training_state == HandleTrainingState.BACKWARD_POST:
2099            # Clear the saved `Tensor`s since they are unneeded now
2100            for i in range(len(self.flat_param._tensors)):
2101                self.flat_param._tensors[i] = None
2102
2103    @no_type_check
2104    @torch.no_grad()
2105    def _use_sharded_grad_views(self) -> None:
2106        """
2107        Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient.
2108
2109        This is a no-op if there is no gradient.
2110
2111        Parameters whose data is not present in the sharded flat parameter and
2112        parameters with ``requires_grad=False`` have their gradients set to
2113        ``None``. Since the gradient variables do not need to be preserved,
2114        this method does not manipulate existing ``Tensor`` data directly and
2115        creates new ``Tensor`` variables instead.
2116        """
2117        flat_param = self.flat_param
2118        self._check_sharded(flat_param)
2119        grad = self.sharded_grad
2120        if grad is None:
2121            for param in chain(flat_param._params, flat_param._shared_params):
2122                param.grad = None
2123            return
2124        self._check_sharded(grad)
2125        for param, shard_param_info, is_grad_none in zip(
2126            flat_param._params,
2127            flat_param._shard_param_infos,
2128            flat_param._is_grad_none_mask,
2129        ):
2130            if not shard_param_info.in_shard:
2131                param.grad = None
2132            else:
2133                numel_in_shard = shard_param_info.numel_in_shard
2134                if param.requires_grad and not is_grad_none:
2135                    offset = shard_param_info.offset_in_shard
2136                    if self._keep_low_precision_grads or param.dtype != grad.dtype:
2137                        # NOTE: This is a hack using `.data` to side step the
2138                        # check that parameter/gradient dtypes match. Here,
2139                        # `param` has full precision; `grad` has low precision.
2140                        if param.grad is None:
2141                            # `.grad` must have the same shape as `param`
2142                            param.grad = torch.empty_like(param)
2143                        param.grad.data = grad[
2144                            offset : offset + numel_in_shard
2145                        ].reshape(param.shape)
2146                    else:
2147                        param.grad = grad[offset : offset + numel_in_shard].reshape(
2148                            param.shape
2149                        )
2150                else:
2151                    param.grad = None
2152        assert flat_param._shared_params is not None
2153        for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
2154            zip(flat_param._shared_params, flat_param._shared_param_infos)
2155        ):
2156            in_sharded_flat_param = hasattr(prim_module, prim_param_name)
2157            if in_sharded_flat_param and param.requires_grad:
2158                prim_param = getattr(prim_module, prim_param_name)
2159                param.grad = prim_param.grad  # share the same reference
2160            else:
2161                param.grad = None
2162
2163    @no_type_check
2164    @torch.no_grad()
2165    def _writeback_orig_params(self) -> bool:
2166        """
2167        Write back any parameters that changed storage to the handle's ``FlatParameter``.
2168
2169        Iterates over the original parameters and writes back any parameters
2170        that changed storages (due to a non-inplace operator) to the handle's
2171        ``FlatParameter``. This method preserves the ``FlatParameter` 's
2172        ``FlatParameter``. This method preserves the ``FlatParameter`` 's
2173
2174        Raises:
2175            RuntimeError: If an original parameter or gradient changes storages
2176            but no longer has the expected flattened shape.
2177        Returns: ``True`` if some writeback happened, and ``False`` otherwise.
2178        """
2179        if (
2180            self.uses_sharded_strategy
2181            and not self.is_sharded(self.flat_param)
2182            and not self._skipped_use_sharded_views
2183        ):
2184            # For `NO_SHARD`, we may still need to write back
2185            return False
2186        flat_param = self.flat_param
2187        wroteback = False
2188        if self._skipped_use_sharded_views and self.uses_sharded_strategy:
2189            # NOTE: We must use the unsharded flat parameter from which the
2190            # unsharded views were computed, not the one from the current
2191            # calling context (`_get_padded_unsharded_flat_param()`) since that
2192            # may be different (e.g. the model changed from train to eval).
2193            flat_param_tensor = self._unsharded_flat_param_for_skipped_views
2194            _p_assert(
2195                _data_ptr_allocated(flat_param_tensor),
2196                "If skipped using sharded views, the unsharded flat parameter "
2197                "should be allocated",
2198            )
2199        else:
2200            flat_param_tensor = flat_param
2201        # NOTE: Since this method is called in the pre-unshard, which is only
2202        # called during computation in the pre-forward or pre-backward, the
2203        # sharded gradient should be guaranteed to be in `.grad`, not in
2204        # `._saved_grad_shard`.
2205        flat_param_grad = (
2206            flat_param.grad
2207            if self.uses_sharded_strategy or not self._offload_params
2208            else flat_param._cpu_grad
2209        )
2210        for i, (
2211            param,
2212            (in_shard, offset_in_shard, numel_in_shard, _, _),
2213            (param_name, module, _),
2214        ) in enumerate(
2215            zip(
2216                flat_param._params,
2217                flat_param._shard_param_infos,
2218                flat_param._param_infos,
2219            )
2220        ):
2221            if not in_shard:
2222                continue
2223            if not hasattr(module, param_name):
2224                # Do not writeback if original parameters are deregistered
2225                # (e.g. during model checkpointing)
2226                continue
2227
2228            # Check for parameter writeback
2229            if self._skipped_use_sharded_views:
2230                param = flat_param._tensors[i]
2231                _p_assert(
2232                    param is not None,
2233                    f"Expects to have saved tensor for {flat_param._fqns[i]}",
2234                )
2235            param_changed = getattr(module, param_name) is not param
2236            needs_param_writeback = (
2237                param_changed  # changed parameter variable itself
2238                or not _same_storage(param, flat_param_tensor)
2239            )
2240            if self._skipped_use_sharded_views and (
2241                param_changed or needs_param_writeback
2242            ):
2243                raise AssertionError(
2244                    "FSDP does not support changing the parameters between "
2245                    f"forward and backward for {self._sharding_strategy}"
2246                )
2247            if param_changed:
2248                # NOTE: The gradient is not preserved after a parameter change.
2249                param = getattr(module, param_name)
2250                flat_param._params[i] = param
2251            if needs_param_writeback:
2252                expected_shape = torch.Size([numel_in_shard])
2253                self._writeback_tensor(
2254                    param, flat_param, i, expected_shape, offset_in_shard, True
2255                )
2256                wroteback = True
2257
2258            # Check for gradient writeback
2259            if self._skipped_use_sharded_views:
2260                # Skip the writeback check because we do not expose gradients
2261                # when we skipped using sharded views
2262                continue
2263            if param.grad is None and flat_param.grad is not None:
2264                expected_shape = torch.Size([numel_in_shard])
2265                self._writeback_tensor(
2266                    None, flat_param.grad, i, expected_shape, offset_in_shard, False
2267                )
2268            elif param.grad is not None:
2269                # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
2270                # memory and owns the gradient storage, so it will never
2271                # require gradient writeback.
2272                if not self.uses_sharded_strategy and self._offload_params:
2273                    # Explicitly continue to handle the case of `no_sync()`,
2274                    # where `param.grad` is a view into the GPU gradient
2275                    # referenced by `flat_param.grad`, while `flat_param_grad`
2276                    # is `flat_param._cpu_grad`, which is on CPU
2277                    continue
2278
2279                needs_grad_writeback = flat_param_grad is None or not _same_storage(
2280                    param.grad, flat_param_grad
2281                )
2282                if needs_grad_writeback:
2283                    if flat_param_grad is None:
2284                        flat_param_grad = torch.zeros_like(flat_param)
2285                    expected_shape = torch.Size([numel_in_shard])
2286                    self._writeback_tensor(
2287                        param.grad,
2288                        flat_param_grad,
2289                        i,
2290                        expected_shape,
2291                        offset_in_shard,
2292                        False,
2293                    )
2294                    flat_param.grad = flat_param_grad
2295                    flat_param_grad = flat_param.grad
2296
2297        # TODO: If we want to handle shared parameters, we need to re-generate
2298        # the shared parameter data structures in case sharedness changed.
2299        for i, (
2300            param_name,
2301            module,
2302            _,
2303            prim_param_name,
2304            prim_module,
2305            _,
2306        ) in enumerate(flat_param._shared_param_infos):
2307            if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
2308                raise NotImplementedError(
2309                    "Changing shared parameters is not supported yet"
2310                )
2311        return wroteback
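
    # Illustrative sketch of the storage-change detection above (assumption:
    # `_same_storage` compares the two tensors' underlying storage pointers):
    # an out-of-place reassignment gives an original parameter fresh storage,
    # which is what triggers a writeback into the flat parameter, whereas
    # in-place ops reuse the original storage and need no writeback:
    #
    #     p = module.weight            # hypothetical view into `flat_param`
    #     p.mul_(2.0)                  # in-place: same storage, no writeback
    #     p.data = p.data.clone() + 1  # new storage: written back in pre-unshard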
2312
2313    def _writeback_tensor(
2314        self,
2315        src_tensor: Optional[Tensor],
2316        dst_tensor: Tensor,
2317        tensor_index: int,
2318        expected_shape: torch.Size,
2319        offset: int,
2320        is_param: bool,  # else gradient
2321    ) -> None:
2322        """
2323        Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``.
2324
2325        ``is_param`` indicates whether the tensor is the parameter (``True``) or
2326        the gradient (``False``). If ``src_tensor`` is ``None``, then the effect
2327        is zeroing instead of copying. ``tensor_index`` gives the index of
2328        ``src_tensor`` in the metadata structures.
2329
2330        Raises:
2331            RuntimeError: If the ``src_tensor`` does not have the expected
2332            shape.
2333        """
2334        _p_assert(
2335            len(expected_shape) == 1,
2336            f"Expects a 1D expected shape but got {expected_shape}",
2337        )
2338        if self._debug_level == dist.DebugLevel.INFO:
2339            rank = self.rank if hasattr(self, "rank") else dist.get_rank()
2340            src_shape = src_tensor.shape if src_tensor is not None else None
2341            src_device = src_tensor.device if src_tensor is not None else None
2342            warnings.warn(
2343                f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
2344                f"writeback in {self._training_state}\n"
2345                f"expected shape={expected_shape} shape={src_shape} "
2346                f"expected device={dst_tensor.device} device={src_device}"
2347            )
2348        if src_tensor is not None and src_tensor.shape != expected_shape:
2349            # NOTE: Gradient shape mismatch is not possible in practice since
2350            # the gradient shape is enforced to match that of the parameter and
2351            # we already check for parameter shape mismatch.
2352            raise RuntimeError(
2353                f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
2354                f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
2355            )
2356        if src_tensor is not None:
2357            dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
2358        else:
2359            dst_tensor[offset : offset + expected_shape.numel()].zero_()
2360            assert self.flat_param._is_grad_none_mask is not None
2361            self.flat_param._is_grad_none_mask[tensor_index] = True
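
    # Illustrative sketch (added for exposition; the tensors below are stand-ins,
    # not this handle's attributes): the writeback above copies a 1D source into
    # the owning flat tensor at a given offset, or zeros that slice when the
    # source is `None` (also marking `_is_grad_none_mask`). Roughly:
    #
    #   import torch
    #   dst = torch.zeros(8)                            # stand-in flat (sharded) tensor
    #   src = torch.arange(3.0)                         # stand-in 1D param/grad view
    #   offset = 2
    #   dst[offset : offset + src.numel()].copy_(src)   # copy path
    #   dst[offset : offset + src.numel()].zero_()      # zero path (`src_tensor is None`)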
2362
2363    def _reset_flat_param_grad_info_if_needed(self):
2364        """
2365        Reset ``flat_param.grad`` if needed.
2366
2367        When ``use_orig_params=True``:
2368        (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the
2369        original parameters' ``.grad`` are ``None``, and
2370        (2) sets ``flat_param.requires_grad=False`` if *none* of the original
2371        parameters require gradient.
2372        For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in
2373        which case we want to free the gradients as soon after the
2374        ``zero_grad()`` call as possible.
2375        """
2376        if not self._use_orig_params:
2377            return
2378        flat_param = self.flat_param
2379        assert flat_param._params is not None  # mypy
2380        all_grad_none = True
2381        requires_grad = False
2382        for param in flat_param._params:
2383            all_grad_none &= param.grad is None
2384            requires_grad |= param.requires_grad
2385        if all_grad_none:
2386            flat_param.grad = None
2387        # As long as one parameter requires gradient, then the flat parameter
2388        # must require gradient
2389        flat_param.requires_grad = requires_grad
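
    # Illustrative sketch (standalone toy example, not this handle's API): the
    # loop above aggregates per-parameter state onto the flat parameter, which
    # is equivalent to the following reductions.
    #
    #   import torch
    #   import torch.nn as nn
    #   params = [nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(3), requires_grad=False)]
    #   all_grad_none = all(p.grad is None for p in params)    # True -> free `flat_param.grad`
    #   requires_grad = any(p.requires_grad for p in params)   # -> `flat_param.requires_grad`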
2390
2391    def _deregister_orig_params(self):
2392        for param_info in self.flat_param._param_infos:
2393            param_name, module, _ = param_info
2394            if hasattr(module, param_name):
2395                delattr(module, param_name)
2396        for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
2397            if hasattr(module, param_name):
2398                delattr(module, param_name)
2399
2400    ###########
2401    # HELPERS #
2402    ###########
2403    def flat_param_to(self, *args, **kwargs):
2404        """Wrap ``.to()`` for ``self.flat_param``, applying the result in place via ``.data``."""
2405        self.flat_param.data = self.flat_param.to(*args, **kwargs)
2406        if self._use_orig_params:
2407            # Refresh the views because their storage may have changed
2408            if self.is_sharded(self.flat_param):
2409                self._use_sharded_views()
2410            else:
2411                self._use_unsharded_views(as_params=True)
2412
2413    def _get_modules(self) -> Set[nn.Module]:
2414        """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter."""
2415        return {pi.module for pi in self.flat_param._param_infos}.union(
2416            {spi.module for spi in self.flat_param._shared_param_infos}
2417        )
2418
2419    def is_sharded(self, tensor: Tensor) -> bool:
2420        """
2421        Return whether ``tensor`` is *currently* sharded.
2422
2423        For ``NO_SHARD``, we choose to have this always return ``False`` for clarity.
2424        """
2425        if (
2426            not hasattr(self.flat_param, "_sharded_size")
2427            or not self.uses_sharded_strategy
2428        ):
2429            # `_sharded_size` is defined iff `handle.shard()` has been called
2430            return False
2431        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
2432        return tensor.size() == sharded_size
2433
2434    def param_module_names(self) -> Iterator[Tuple[str, str]]:
2435        shared_param_infos = [
2436            ParamInfo(param_name, module, module_name)
2437            for (
2438                param_name,
2439                module,
2440                module_name,
2441                _,
2442                _,
2443                _,
2444            ) in self.flat_param._shared_param_infos
2445        ]
2446        for param_info in chain(self.flat_param._param_infos, shared_param_infos):
2447            param_name, _, module_name = param_info  # type: ignore[misc]
2448            yield (param_name, module_name)
2449
2450    def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
2451        for param_name, _, module_name in [
2452            ParamInfo(param_name, module, module_name)
2453            for (
2454                param_name,
2455                module,
2456                module_name,
2457                _,
2458                _,
2459                _,
2460            ) in self.flat_param._shared_param_infos
2461        ]:
2462            yield (param_name, module_name)
2463
2464    @property
2465    def _fqns_in_shard(self) -> List[str]:
2466        """Return the FQNs of the parameters present in this rank's shard."""
2467        fqns_in_shard: List[str] = []
2468        for fqn, shard_param_info in zip(
2469            self.flat_param._fqns, self.flat_param._shard_param_infos  # type: ignore[attr-defined]
2470        ):
2471            if shard_param_info.in_shard:
2472                fqns_in_shard.append(fqn)
2473        return fqns_in_shard
2474
2475    @property
2476    def sharded_grad(self) -> Optional[Tensor]:
2477        """Return the handle's sharded gradient."""
2478        flat_param = self.flat_param
2479        # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
2480        # - CPU offloading: `_cpu_grad`
2481        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
2482        # - No CPU offloading + `NO_SHARD`: `grad`
2483        grad: Optional[Tensor]
2484        if hasattr(flat_param, "_cpu_grad"):
2485            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
2486        elif hasattr(flat_param, "_saved_grad_shard"):
2487            # In the post-backward hook, the sharded gradient is still in
2488            # `_saved_grad_shard`.
2489            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
2490        else:
2491            # If in IDLE or in FORWARD states, then there may be an
2492            # (accumulated) gradient. If accessed in IDLE, then this should
2493            # be due to re-registering the original parameters (e.g. in state
2494            # dict load).
2495            _p_assert(
2496                flat_param.grad is None
2497                or not self.uses_sharded_strategy
2498                or self._training_state
2499                in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
2500                "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
2501                "unless in IDLE or FORWARD",
2502            )
2503            grad = flat_param.grad
2504        return grad
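
    # Illustrative restatement (hypothetical helper name `_sharded_grad_of`, not
    # part of this module): the priority documented above is a simple fallback
    # chain over the gradient attributes.
    #
    #   def _sharded_grad_of(flat_param):
    #       for attr in ("_cpu_grad", "_saved_grad_shard"):
    #           if hasattr(flat_param, attr):
    #               return getattr(flat_param, attr)
    #       return flat_param.grad   # `NO_SHARD` without offloading, or IDLE/FORWARD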
2505
2506    def _reset_is_grad_none(self) -> None:
2507        """
2508        Reset ``_is_grad_none_mask`` as needed.
2509
2510        This method should only be called in the post-backward, after gradient
2511        computation, at which point every parameter that requires gradient is
2512        guaranteed to have received a gradient, so we may reset its mask entry
2513        to ``False``.
2514        """
2515        if not self._use_orig_params:
2516            return
2517        _p_assert(
2518            self._training_state == HandleTrainingState.BACKWARD_POST,
2519            "Expects to only be called in the post-backward after gradient computation",
2520        )
2521        flat_param = self.flat_param
2522        assert flat_param._params is not None  # mypy
2523        for i, param in enumerate(flat_param._params):  # type: ignore[arg-type]
2524            # As long as the parameter requires gradient, it should receive a
2525            # meaningful gradient (even if the gradient happens to be zeros)
2526            if param.requires_grad:
2527                assert flat_param._is_grad_none_mask is not None  # mypy
2528                flat_param._is_grad_none_mask[i] = False
2529
2530    #######################
2531    # CHECKS & INVARIANTS #
2532    #######################
2533    def _check_sharded_strategy(self):
2534        _p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
2535
2536    def _check_on_compute_device(self, tensor: Tensor):
2537        _p_assert(
2538            tensor.device == self.device,
2539            f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}",
2540        )
2541
2542    def _check_on_cpu(self, tensor: Tensor):
2543        _p_assert(
2544            tensor.device == torch.device("cpu"),
2545            f"Expects tensor to be on CPU but got {tensor.device}",
2546        )
2547
2548    @staticmethod
2549    def _check_storage_freed(tensor: Tensor):
2550        # torch.compile does not resize storages during tracing
2551        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
2552            _p_assert(
2553                _same_storage_size(tensor, 0),
2554                "Expects storage to be freed but got storage with size > 0",
2555            )
2556
2557    @staticmethod
2558    def _check_storage_allocated(tensor: Tensor):
2559        _p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated")
2560
2561    def _check_low_precision_shard(self):
2562        _p_assert(
2563            self._uses_param_mixed_precision,
2564            "Not using low precision for parameters",
2565        )
2566        _p_assert(
2567            getattr(self.flat_param, "_mp_shard", None) is not None,
2568            "Expects `_mp_shard` to exist",
2569        )
2570        device = self.flat_param._mp_shard.device  # type: ignore[attr-defined]
2571        _p_assert(
2572            device == self.device,
2573            f"Expects the low precision shard to be on {self.device} but got {device}",
2574        )
2575
2576    def _check_unsharded(self, tensor: Tensor):
2577        msg_prefix = "Expects tensor to be unsharded "
2578        _p_assert(tensor is not None, msg_prefix + "but got `None`")
2579        unsharded_size = self.flat_param._unpadded_unsharded_size
2580        _p_assert(
2581            tensor.size() == unsharded_size,
2582            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
2583        )
2584
2585    def _check_sharded(self, tensor: Tensor):
2586        msg_prefix = "Expects tensor to be sharded "
2587        _p_assert(tensor is not None, msg_prefix + "but got `None`")
2588        sharded_size = self.flat_param._sharded_size  # type: ignore[attr-defined]
2589        _p_assert(
2590            tensor.size() == sharded_size,
2591            msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
2592        )
2593
2594    ##############
2595    # PROPERTIES #
2596    ##############
2597    @property
2598    def uses_sharded_strategy(self) -> bool:
2599        return self._sharding_strategy != HandleShardingStrategy.NO_SHARD
2600
2601    @property
2602    def _uses_param_mixed_precision(self) -> bool:
2603        return self._fwd_bwd_param_dtype != self._orig_param_dtype
2604
2605    @property
2606    def _uses_reduce_mixed_precision(self) -> bool:
2607        return self._reduce_dtype != self._orig_param_dtype
2608
2609    @property
2610    def _force_full_precision(self) -> bool:
2611        return (
2612            self._uses_param_mixed_precision or self._uses_reduce_mixed_precision
2613        ) and (
2614            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
2615            or
2616            # Also disable mixed precision in model eval mode, if configured
2617            (not self._fully_sharded_module.training and self._use_full_prec_in_eval)
2618        )
2619
2620    @property
2621    def _skipped_use_sharded_views(self) -> bool:
2622        """
2623        Return whether this handle has skipped using sharded views, in which
2624        case it can restore view invariants via ``_use_sharded_views()``.
2625
2626        This property is relevant for sharding strategies that do not free the
2627        unsharded flat parameter after forward when ``use_orig_params=True``.
2628        """
2629        return self._unsharded_flat_param_for_skipped_views is not None
2630
2631
2632# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
2633def _unsafe_setattr_param(
2634    module: nn.Module, param_name: str, param: nn.Parameter
2635) -> None:
2636    module._parameters[param_name] = param
2637    # This bypasses any overrides in case `module` is an instance of an
2638    # `nn.Module` subclass
2639    super(nn.Module, module).__setattr__(param_name, param)
2640
2641
2642def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
2643    module._parameters.pop(param_name, None)
2644    # This bypasses any overrides in case `module` is an instance of an
2645    # `nn.Module` subclass
2646    super(nn.Module, module).__setattr__(param_name, tensor)
2647
2648
2649def _safe_setattr_tensor_or_param(
2650    module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
2651):
2652    # Call `delattr()` and `setattr()` to go through `nn.Module` checks
2653    if hasattr(module, param_name):
2654        delattr(module, param_name)
2655    setattr(module, param_name, tensor_or_param)
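
# Illustrative usage sketch (hypothetical `nn.Linear` module and parameter; added
# for exposition): both calls end with `module.weight` referring to `new_param`,
# but the unsafe path skips `nn.Module.__setattr__` checks and any overrides.
#
#   import torch
#   import torch.nn as nn
#   module = nn.Linear(4, 4)
#   new_param = nn.Parameter(torch.zeros(4, 4))
#   _safe_setattr_tensor_or_param(module, "weight", new_param)   # respects overrides
#   _unsafe_setattr_param(module, "weight", new_param)           # bypasses them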
2656
2657
2658def _convert_to_params(
2659    tensors: List[Union[torch.Tensor, nn.Parameter]]
2660) -> List[nn.Parameter]:
2661    return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
2662
2663
2664def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
2665    return (
2666        param_or_tensor.detach()
2667        if isinstance(param_or_tensor, nn.Parameter)
2668        else param_or_tensor
2669    )
2670
2671
2672def _get_aligned_numel(unsharded_dtype: torch.dtype):
2673    # NOTE: This alignment constraint comes from TorchInductor.
2674    ALIGNMENT = 16  # bytes
2675    unsharded_dtype_size = _get_dtype_size(unsharded_dtype)
2676    aligned_numel = ALIGNMENT // unsharded_dtype_size
2677    return aligned_numel
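
# Worked examples (illustrative): the 16-byte alignment constraint translates to
#
#   _get_aligned_numel(torch.float32)   # 16 // 4 == 4 elements
#   _get_aligned_numel(torch.float16)   # 16 // 2 == 8 elements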
2678
2679
2680@functools.lru_cache(8)
2681def _get_dtype_size(dtype):
2682    return torch.empty((), dtype=dtype).element_size()
2683
2684
2685def _construct_padding_tensor(
2686    padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
2687):
2688    # NOTE: Set the padding value as a magic number for debuggability. The
2689    # value itself should never be used in any user-facing computation.
2690    return (
2691        torch.ones(
2692            (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
2693        )
2694        * _FLAT_PARAM_PADDING_VALUE
2695    )
2696
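
# Illustrative usage sketch (arbitrary values, added for exposition): padding
# inserted for alignment is filled with the magic value so that it stands out
# when debugging.
#
#   pad = _construct_padding_tensor(4, torch.float32, False, torch.device("cpu"))
#   # pad.shape == torch.Size([4]); every element equals `_FLAT_PARAM_PADDING_VALUE`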
2697
2698# Use `lru_cache(1)` to log the warning only once (assuming the same fixed
2699# warning message is passed in)
2700@functools.lru_cache(1)
2701def _warn_skip_writeback_check(log: logging.Logger, warning: str):
2702    log.warning(warning)
2703
2704
2705# Use `lru_cache(1)` to only log the warning once
2706@functools.lru_cache(1)
2707def _warn_use_fake_all_gather(log: logging.Logger, warning: str):
2708    log.warning(warning)
2709
2710
2711# Use `lru_cache(1)` to only log the warning once
2712@functools.lru_cache(1)
2713def _warn_use_fake_reduce(log: logging.Logger, warning: str):
2714    log.warning(warning)
2715
2716
2717def _same_storage(a, b):
2718    # Params are DTensors in backward
2719    # with SHARD_GRAD_OP + TP
2720    from torch.distributed.tensor import DTensor
2721
2722    if isinstance(a, DTensor):
2723        a = a._local_tensor
2724    if isinstance(b, DTensor):
2725        b = b._local_tensor
2726    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
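
# Illustrative sketch (plain tensors, no DTensor involved): views share storage
# while clones do not.
#
#   t = torch.zeros(4)
#   _same_storage(t, t[2:])        # True: a view shares the underlying storage
#   _same_storage(t, t.clone())    # False: a clone allocates new storage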
2727
2728
2729def _same_storage_size(a: torch.Tensor, b: int):
2730    return a.untyped_storage().size() // a.element_size() == b
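
# Worked example (illustrative): a float32 tensor of 6 elements has a 24-byte
# untyped storage, so
#
#   t = torch.zeros(6)
#   _same_storage_size(t, 6)   # True: 24 // 4 == 6
#   _same_storage_size(t, 0)   # True only once the storage has been freed (size 0)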
2731
2732
2733def _storage_size_allocated(tensor: Tensor):
2734    storage_size: int = tensor.untyped_storage().size()
2735    return storage_size > 0
2736