xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/_runtime_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3import logging
4from enum import auto, Enum
5from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple
6
7import torch
8import torch.distributed as dist
9import torch.distributed.fsdp._traversal_utils as traversal_utils
10import torch.nn as nn
11import torch.nn.functional as F
12from torch.autograd import Variable
13from torch.autograd.graph import register_multi_grad_hook
14from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
15from torch.distributed.fsdp._common_utils import (
16    _assert_in_training_states,
17    _FSDPState,
18    _get_module_fsdp_state,
19    _is_composable,
20    _log_post_backward_hook,
21    _no_dispatch_record_stream,
22    clean_tensor_name,
23    TrainingState,
24)
25from torch.distributed.fsdp._flat_param import (
26    FlatParameter,
27    FlatParamHandle,
28    HandleShardingStrategy,
29    HandleTrainingState,
30    RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
31)
32from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
33from torch.distributed.fsdp.api import BackwardPrefetch
34from torch.distributed.utils import (
35    _apply_to_tensors,
36    _cast_forward_inputs,
37    _p_assert,
38    _to_kwargs,
39)
40from torch.utils import _pytree as pytree
41
42
43logger = logging.getLogger(__name__)
44
45# Do not include "process_group" to enable hybrid shard and MoE cases
46HOMOGENEOUS_ATTR_NAMES = (
47    "_use_orig_params",
48    "limit_all_gathers",
49    "_use_full_prec_in_eval",
50)
51
52
53class _PrefetchMode(Enum):
54    BACKWARD = auto()
55    FORWARD = auto()
56
57
58def _get_fsdp_root_states_with_modules(
59    module: nn.Module,
60) -> Tuple[List[_FSDPState], List[nn.Module]]:
61    """
62    Returns a tuple containing:
63    1. A list of the root ``_FSDPState`` instances in the module tree rooted at
64    ``module`` without any duplicates and following the ``module.modules()``
65    traversal order (which is assumed to be depth-first).
66    2. A corresponding list of the root modules owning the states in the first
67    list.
68
69    This is similar to :func:`_get_fsdp_states_with_modules` except that we
70    must call :func:`_is_fsdp_root` to force a lazy initialization to determine
71    the FSDP root in case lazy initialization has not yet happened.
72    """
73    fsdp_root_states: List[_FSDPState] = []
74    fsdp_root_modules: List[nn.Module] = []
75    visited_fsdp_states: Set[_FSDPState] = set()
76    # NOTE: This function assumes that `module.modules()` proceeds top-down.
77    for submodule in module.modules():
78        optional_state = _get_module_fsdp_state(submodule)
79        if (
80            optional_state is not None
81            and optional_state not in visited_fsdp_states
82            and _is_fsdp_root(optional_state, submodule)
83        ):
84            visited_fsdp_states.add(optional_state)
85            fsdp_root_states.append(optional_state)
86            fsdp_root_modules.append(submodule)
87    return fsdp_root_states, fsdp_root_modules
88
89
90def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
91    """See :func:`_get_fsdp_root_states_with_modules`."""
92    fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
93    return fsdp_root_states
94
95
96def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
97    """
98    Returns whether ``state`` corresponds to that of an FSDP root.
99
100    For the wrapper code path, ``state`` and ``module`` should be the same. For
101    the non-wrapper code path, ``state`` should be ``module`` 's state.
102    """
103    # Force a lazy initialization to determine the FSDP root
104    _lazy_init(state, module)
105    assert state._is_root is not None  # mypy
106    return state._is_root
107
108
109@no_type_check
110def _lazy_init(
111    state: _FSDPState,
112    root_module: nn.Module,
113) -> _FSDPState:
114    """
115    Performs initialization lazily, typically right before the first forward
116    pass. The laziness is needed to ensure that the parameter device/dtype and
117    the FSDP hierarchy have been finalized. This method's actual logic only runs on
118    the root FSDP instance, which performs initialization for all non-root FSDP
119    instances to avoid partial initialization.
120
121    For the non-composable code path, ``state`` and ``root_module`` should be
122    the same, namely the FSDP instance itself.
123    """
124    if state._is_root is not None:
125        return state  # no-op: already lazily initialized
126    if not state._device_handle.is_available():
127        # Allow the FSDP constructor to run even without CUDA but check this
128        # once we start real execution
129        raise RuntimeError("FSDP does not support CPU only execution")
130    # The following logic is only run on the root FSDP instance since it will
131    # set `_is_root=False` for the non-root instances
132    state._is_root = True
133    _assert_in_training_states(state, [TrainingState.IDLE])
134    _check_flat_params_on_expected_device(state, root_module)
135    state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
136    _init_streams(state)
137    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
138    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
139    state._exec_order_data.init(state, root_module, state.process_group)
140    _share_state_and_init_handle_attrs(state, root_module)
141    return state
142
143
144def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
145    """
146    Checks that all ``FlatParameter``s in ``module`` 's tree managed by
147    ``state`` are on the expected device for *lazy initialization*.
148    """
149    cpu_device = torch.device("cpu")
150    for handle in traversal_utils._get_fsdp_handles(module):
151        if (
152            not handle._offload_params
153            and handle.flat_param.device != state.compute_device
154        ):
155            raise RuntimeError(
156                "An FSDP-managed module unexpectedly has parameters on "
157                f"{handle.flat_param.device}. Make sure to move the module to "
158                f"{state.compute_device} before training."
159            )
160        elif handle._offload_params and handle.flat_param.device != cpu_device:
161            raise RuntimeError(
162                "An FSDP-managed module with parameter CPU offloading enabled "
163                f"has parameters on {handle.flat_param.device}. Make sure to "
164                f"not move the module from CPU when offloading parameters."
165            )
166
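# A minimal setup sketch for the device expectations checked above (assumption:
# standard wrapper-path usage; the model and variable names are illustrative):
#
#     device = torch.device("cuda", torch.cuda.current_device())
#     model = MyModel().to(device)  # no CPU offloading: parameters on the GPU
#     fsdp_model = FSDP(model)
#
#     # With CPU offloading, keep the module on CPU and let FSDP manage copies:
#     fsdp_model = FSDP(MyModel(), cpu_offload=CPUOffload(offload_params=True))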
167
168@no_type_check
169def _share_state_and_init_handle_attrs(
170    root_state: _FSDPState,
171    root_module: nn.Module,
172) -> None:
173    """
174    Shares data structure state from the ``root_state`` to all FSDP states in
175    ``root_module`` 's module tree and initializes handle attributes. These
176    are done together so that only a single loop over the states is needed.
177    """
178    handle = root_state._handle
179    if handle:
180        handle.init_flat_param_attributes()
181    attr_name_to_values: Dict[str, Set[Any]] = {}
182    for attr_name in HOMOGENEOUS_ATTR_NAMES:
183        attr_name_to_values[attr_name] = set()
184    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
185    # Update _has_optim_in_backward for each handle.
186    for handle in root_state._all_handles:
187        flat_param = handle.flat_param
188        if hasattr(flat_param, "_in_backward_optimizers"):
189            raise RuntimeError(
190                "FSDP optimizer in backward only supported with use_orig_params=True!"
191            )
192        handle._has_optim_in_backward = flat_param._params is not None and any(
193            hasattr(param, "_in_backward_optimizers") for param in flat_param._params
194        )
195        if handle._has_optim_in_backward:
196            torch._C._log_api_usage_once("fsdp.optimizer_in_backward")
197    for fsdp_state in root_state._all_fsdp_states:
198        for attr_name in HOMOGENEOUS_ATTR_NAMES:
199            _p_assert(
200                hasattr(fsdp_state, attr_name),
201                f"FSDP state missing attribute {attr_name}",
202            )
203            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
204        if fsdp_state is root_state:
205            continue
206        # Relax the assert for non-root FSDP instances in case the nested
207        # initialized module is wrapped again in FSDP later (e.g. after
208        # training to run inference)
209        _p_assert(
210            fsdp_state._is_root is None or not fsdp_state._is_root,
211            "Non-root FSDP instance's `_is_root` should not have been "
212            "set yet or should have been set to `False`",
213        )
214        fsdp_state._is_root = False
215        fsdp_state._unshard_stream = root_state._unshard_stream
216        fsdp_state._post_backward_stream = root_state._post_backward_stream
217        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
218        fsdp_state._all_reduce_stream = root_state._all_reduce_stream
219        fsdp_state._default_stream = root_state._default_stream
220        fsdp_state._exec_order_data = root_state._exec_order_data
221        fsdp_state._free_event_queue = root_state._free_event_queue
222        if fsdp_state._fsdp_extension is not None:
223            fsdp_state._fsdp_extension.compute_stream = root_state._default_stream
224        handle = fsdp_state._handle
225        if handle:
226            handle.init_flat_param_attributes()
227    for attr_name, attr_values in attr_name_to_values.items():
228        if len(attr_values) != 1:
229            raise ValueError(
230                f"Expects one homogeneous value for {attr_name} but got {attr_values}"
231            )
232
233
234@no_type_check
235def _init_streams(
236    state: _FSDPState,
237) -> None:
238    """
239    Initializes CUDA streams for overlapping communication, computation, and
240    data transfers. The streams should be shared across FSDP instances.
241    """
242    assert state._is_root
243    assert state._device_handle.is_available()
244    uses_hybrid_sharding = any(
245        fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
246        for fsdp_state in state._all_fsdp_states
247    )
248    # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
249    # preserve the default priority of 0 otherwise
250    high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0
251    # Default stream for computation
252    state._default_stream = state._device_handle.current_stream()
253    if state._fsdp_extension is not None:
254        # set the compute stream to the FSDP extension
255        state._fsdp_extension.compute_stream = state._default_stream
256
257    # Stream for unshard logic, including allocating the all-gather destination
258    # tensors and the all-gathers themselves
259    state._unshard_stream = state._device_handle.Stream(priority=high_priority)
260    # Stream for overlapping gradient reduction with the backward pass gradient
261    # computation
262    state._post_backward_stream = state._device_handle.Stream(priority=high_priority)
263    # Stream for pre-unshard logic, namely allocations and writes for CPU
264    # offloading (H2D copy) and mixed precision (low precision cast)
265    state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority)
266    # Stream to run HSDP's all-reduce as async (if using HSDP)
267    state._all_reduce_stream = (
268        state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream
269    )
270
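# Rough sketch of how the streams created above are used in the rest of this
# file (a summary of the call sites below, not an exact trace):
#
#   _pre_unshard_stream:   H2D copies for CPU offloading and low-precision casts
#                          that must precede the all-gather
#   _unshard_stream:       allocation of the all-gather destination tensors and
#                          the all-gathers themselves (`_unshard`)
#   _default_stream:       forward/backward computation
#   _post_backward_stream: gradient reduce-scatter overlapping with backward
#                          computation (`_post_backward_hook`)
#   _all_reduce_stream:    HSDP's inter-node all-reduce running asynchronously
#                          with respect to the reduce-scatter (`_reduce_grad`)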
271
272@no_type_check
273def _unshard(
274    state: _FSDPState,
275    handle: FlatParamHandle,
276    unshard_stream: torch.Stream,
277    pre_unshard_stream: torch.Stream,
278) -> None:
279    """
280    Unshards ``handle`` 's ``FlatParameter``. If the handle is in
281    :meth:`summon_full_params` and is using mixed precision, then it is
282    forced to full precision.
283
284    Postcondition: handle's ``FlatParameter`` 's data is the padded
285    unsharded flat parameter on the compute device.
286    """
287    if not handle:
288        return
289    with state._device_handle.stream(pre_unshard_stream):
290        ran_pre_unshard = handle.pre_unshard()
291    if ran_pre_unshard:
292        unshard_stream.wait_stream(pre_unshard_stream)
293    if state.limit_all_gathers:
294        event = state._free_event_queue.dequeue_if_needed()
295        if event:
296            with torch.profiler.record_function(
297                "FullyShardedDataParallel.rate_limiter"
298            ):
299                event.synchronize()
300    with state._device_handle.stream(unshard_stream):
301        handle.unshard()
302        handle.post_unshard()
303
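# Note on the rate limiter above: with `limit_all_gathers=True`, `_unshard`
# synchronizes on an event dequeued from `_free_event_queue` (when too many
# frees are already in flight) before issuing a new all-gather, which bounds
# how far the unshard stream can run ahead of computation and caps the memory
# held by prefetched unsharded parameters. The matching events are enqueued in
# `_reshard` below.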
304
305@no_type_check
306def _reshard(
307    state: _FSDPState,
308    handle: FlatParamHandle,
309    free_unsharded_flat_param: bool,
310):
311    """
312    Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
313    free the handle's padded unsharded flat parameter.
314    """
315    handle.reshard(free_unsharded_flat_param)
316    if state.limit_all_gathers and free_unsharded_flat_param:
317        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
318            # We don't run an event queue for freeing under torch.compile at the moment.
319            # But maybe we need to? TODO(voz): Look into this
320            free_event = state._device_handle.Event()
321            free_event.record()
322            state._free_event_queue.enqueue(free_event)
323    handle.post_reshard()
324    # Flat parameter freed or not, we always have to "unshard" the parameter
325    # upon next access to get its shape correct.
326    handle._prefetched = False
327
328
329def _unshard_grads(
330    handle: Optional[FlatParamHandle],
331) -> None:
332    if handle:
333        handle.unshard_grad()
334
335
336def _reshard_grads(
337    handle: Optional[FlatParamHandle],
338) -> None:
339    if handle:
340        handle.reshard_grad()
341
342
343@no_type_check
344def _pre_forward(
345    state: _FSDPState,
346    handle: Optional[FlatParamHandle],
347    unshard_fn: Callable,
348    module: nn.Module,
349    args: Tuple[Any, ...],
350    kwargs: Dict[str, Any],
351) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
352    """
353    Runs the pre-forward logic. This includes an opportunity to unshard
354    currently sharded parameters (e.g. those for the current forward) and to
355    register post-backward hooks for these current parameters. This function
356    also casts forward ``args`` and ``kwargs`` to the given precision.
357
358    Args:
358        handle (Optional[FlatParamHandle]): Handle giving the parameters used
359            in the current forward, if any.
361        unshard_fn (Optional[Callable]): A callable to unshard any currently
362            sharded parameters or ``None`` to not do any unsharding.
363        module (nn.Module): Module whose forward this method runs right before;
364            expected by the hook signature.
365        args (Tuple[Any, ...]): Module forward ``args``.
366        kwargs (Dict[str, Any]): Module forward ``kwargs``.
367    """
368    with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
369        # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
370        # recomputed forward
371        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
372            # For both checkpoint implementations, we do not need to re-cast
373            # inputs here since they will be checkpointed in the low precision
374            # either by AC or normally by autograd as long as the AC region is
375            # nested within FSDP
376            return args, kwargs
377        state.training_state = TrainingState.FORWARD_BACKWARD
378        state._exec_order_data.record_pre_forward(handle, module.training)
379        if handle:
380            handle._training_state = HandleTrainingState.FORWARD
381        if unshard_fn is not None:
382            unshard_fn(state, handle)
383        # Register post-backward hooks to reshard the parameters and reduce-scatter
384        # their gradients. They must be re-registered every forward pass in case
385        # the `grad_fn` is mutated.
386        _register_post_backward_hook(state, handle)
387        # We have to reallocate the `_cpu_grad` if the overlapped optimizer
388        # set the gradient to `None` in the backward pass.
389        if handle and handle._offload_params and handle.flat_param._cpu_grad is None:
390            handle.flat_param._cpu_grad = torch.zeros_like(
391                handle.flat_param._local_shard, device=torch.device("cpu")
392            ).pin_memory(device=state.compute_device)
393
394        should_cast_forward_inputs = (
395            state._handle and not state._handle._force_full_precision
396        )
397
398        if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
399            # Recursively convert args and kwargs to specified precision.
400            input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
401            args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
402        _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
403        return args, kwargs
404
405
406@no_type_check
407def _pre_forward_unshard(
408    state: _FSDPState,
409    handle: Optional[FlatParamHandle],
410) -> None:
411    """Unshards parameters in the pre-forward."""
412    if not handle:
413        return
414    # If the handle has already been prefetched, then there is no need to call
415    # `_unshard()` again
416    if not handle._prefetched:
417        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
418    handle._needs_pre_forward_unshard = False
419    # Don't wait during trace
420    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
421        current_stream = state._device_handle.current_stream()
422        if state._unshard_event is not None:
423            current_stream.wait_event(state._unshard_event)
424            state._unshard_event = None
425        else:
426            current_stream.wait_stream(state._unshard_stream)
427    with torch.profiler.record_function(
428        "FullyShardedDataParallel._pre_forward_prefetch"
429    ):
430        _prefetch_handle(state, handle, _PrefetchMode.FORWARD)
431
432
433@no_type_check
434def _post_forward(
435    state: _FSDPState,
436    handle: Optional[FlatParamHandle],
437    reshard_fn: Callable,
438    module: nn.Module,
439    input: Any,
440    output: Any,
441) -> Any:
442    """
443    Runs the post-forward logic. This includes an opportunity to reshard
444    currently unsharded parameters (e.g. those used in the current forward)
445    and to register pre-backward hooks on the forward outputs.
446
447    Args:
448        handle (Optional[FlatParamHandle]): Handle giving the parameters used
449            in the current forward, if any.
450        reshard_fn (Optional[Callable]): A callable to reshard any currently
451            unsharded parameters (e.g. from the current forward) or ``None`` to
452            not do any resharding.
453        module (nn.Module): Module whose forward just ran, which should be a
454            fully sharded module (see [Note: Fully Sharded Module]); expected
455            by the hook signature.
456        input (Any): Unused; expected by the hook signature.
457        output (Any): Forward pass output; pre-backward hooks are registered on
458            the tensors that require gradients in this output.
459
460    Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
461    parameter.
462    """
463    with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
464        # For `fully_shard` + `checkpoint`, skip post-forward logic in the
465        # recomputed forward
466        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
467            return output
468
469        state._exec_order_data.record_post_forward(handle)
470        if reshard_fn is not None:
471            reshard_fn(state, handle)
472        # Register pre-backward hooks to unshard the flat parameters for the
473        # gradient computation (if needed)
474        output = _register_pre_backward_hooks(state, module, output, handle)
475        state.training_state = TrainingState.IDLE
476        if handle:
477            handle._training_state = HandleTrainingState.IDLE
478        return output
479
480
481@no_type_check
482def _post_forward_reshard(
483    state: _FSDPState,
484    handle: FlatParamHandle,
485) -> None:
486    """Reshards parameters in the post-forward."""
487    if not handle:
488        return
489    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
490    # with the intention that they are immediately used for backward
491    # computation (though this may not be true)
492    free_unsharded_flat_param = (
493        not state._is_root
494        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
495    )
496    _reshard(state, handle, free_unsharded_flat_param)
497
498
499@no_type_check
500def _root_pre_forward(
501    state: _FSDPState,
502    module: nn.Module,
503    args,
504    kwargs,
505) -> Tuple[Any, Any]:
506    """
507    Runs pre-forward logic specific to the root FSDP instance, which should run
508    before any individual module's pre-forward. This starts with an attempt at
509    lazy initialization (which only runs non-vacuously once). Otherwise, if
510    this is called on a non-root FSDP instance, then it returns directly.
511
512    Args:
513        module (nn.Module): Module for which this logic tries to run. It may or
514            may not be the root. If not, then this method does not do anything.
515    """
516    with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
517        _lazy_init(state, module)
518        _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
519        if not state._is_root:
520            # Always cast forward inputs in the root of this local FSDP unit for mixed
521            # precision, as this is where mixed precision could be configured.
522            # This is more useful for auto wrapping, which is recommended in the composable path.
523            # For manual wrapping, casting forward inputs on each local FSDP unit root would
524            # add some overhead, so it is not turned on for the model wrapper path right now,
525            # where manual wrapping is more broadly used.
526            if _is_composable(state):
527                return _root_cast_forward_input(state, module, args, kwargs)
528            return args, kwargs
529
530        # We cast buffers back to full precision if we're forcing full precision.
531        # Otherwise, we check whether buffers are in full precision and whether we
532        # should cast them back to lower precision, which happens when exiting eval() mode.
533        handle = state._handle
534        if handle:
535            should_cast_buffers_to_full_prec = handle._force_full_precision
536        else:
537            should_cast_buffers_to_full_prec = True
538
539        if should_cast_buffers_to_full_prec:
540            _cast_buffers_to_dtype_and_device(
541                buffers=dict(module.named_buffers()).values(),
542                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
543                device=state.compute_device,
544            )
545            # This flag is only set when we cast buffers to full precision, to avoid the
546            # CPU overhead that can stem from retrieving all buffers and their types in the
547            # following else branch.
548            state._needs_buffer_dtype_restore_check = True
549        elif getattr(state, "_needs_buffer_dtype_restore_check", False):
550            # Check if buffers are in full precision and we need to cast them
551            # back down.
552            (
553                buffers,
554                buffer_dtypes_for_computation,
555            ) = _get_buffers_and_dtypes_for_computation(state, module)
556            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
557                if any(
558                    buffer.dtype != buffer_dtype_for_computation
559                    for buffer, buffer_dtype_for_computation in zip(
560                        buffers, buffer_dtypes_for_computation
561                    )
562                ):
563                    # Assume we have to cast everything if there is one mismatch
564                    _cast_buffers_to_dtype_and_device(
565                        buffers, buffer_dtypes_for_computation, state.compute_device
566                    )
567            # We don't have to check this again until we cast buffers to full precision again.
568            state._needs_buffer_dtype_restore_check = False
569
570        if state.forward_prefetch:
571            handles = []
572            for fsdp_state in state._all_fsdp_states:
573                if fsdp_state._handle:
574                    handles.append(fsdp_state._handle)
575            for handle in handles:
576                handle._needs_pre_forward_unshard = True
577                handle._prefetched = False
578        _wait_for_computation_stream(
579            state._device_handle.current_stream(),
580            state._unshard_stream,
581            state._pre_unshard_stream,
582        )
583        _reset_flat_param_grad_info_if_needed(state._all_handles)
584
585        # Prepares the forward inputs by moving them to ``compute_device``
586        # TODO: Do not use the side stream for tensor copies for now; investigate
587        # the perf with/without it.
588        with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
589            args_tuple, kwargs_tuple = _to_kwargs(
590                args, kwargs, state.compute_device, False
591            )
592        args = args_tuple[0]
593        kwargs = kwargs_tuple[0]
594
595        return _root_cast_forward_input(state, module, args, kwargs)
596
597
598@no_type_check
599def _root_cast_forward_input(
600    state: _FSDPState, module: torch.nn.Module, args, kwargs
601) -> Tuple[Any, Any]:
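    # NOTE: Despite its name, `force_full_precision` below is True when the
    # handle is *not* forcing full precision (or when there is no handle),
    # i.e. when casting the root forward inputs down to the low-precision
    # `param_dtype` is permissible.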
602    if state._handle:
603        force_full_precision = not state._handle._force_full_precision
604    else:
605        force_full_precision = True
606
607    should_cast_forward_inputs = (
608        (module.training or not state._use_full_prec_in_eval) and force_full_precision
609    ) and state.mixed_precision.cast_root_forward_inputs
610
611    if should_cast_forward_inputs:
612        input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
613        args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
614
615    return args, kwargs
616
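# Usage sketch for the input casting above (assumption: the public
# `MixedPrecision` config defined in `torch.distributed.fsdp.api`; the values
# shown are illustrative):
#
#     mp = MixedPrecision(
#         param_dtype=torch.float16,
#         cast_root_forward_inputs=True,  # handled here, at the local FSDP root
#         cast_forward_inputs=False,      # handled per FSDP unit in `_pre_forward`
#     )
#     fsdp_model = FSDP(model, mixed_precision=mp)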
617
618@no_type_check
619def _pre_backward_hook(
620    state: _FSDPState,
621    module: nn.Module,
622    handle: FlatParamHandle,
623    grad,
624    *unused: Any,
625) -> Any:
626    """
627    Prepares ``handle`` 's ``FlatParameter`` for gradient computation.
628
629    Args:
630        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
631            Module]).
632    """
633    # Only run the pre-backward hook once per group of handles involved in the
634    # same module forward computation
635    if (
636        handle
637        and hasattr(handle, "_ran_pre_backward_hook")
638        and handle._ran_pre_backward_hook
639    ):
640        return grad
641
642    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
643        # Queue the post-backward callback once for the root FSDP instance to
644        # attach it to the outermost backward graph task so that it is called
645        # after all backward calls complete
646        if state._is_root and not state._post_backward_callback_queued:
647            _register_post_backward_final_callback(state, module)
648            _reset_flat_param_grad_info_if_needed(state._all_handles)
649        elif handle:
650            allowed_states = [TrainingState.IDLE]
651            if _is_composable(state):
652                allowed_states.append(TrainingState.FORWARD_BACKWARD)
653            _assert_in_training_states(state, allowed_states)
654        state.training_state = TrainingState.FORWARD_BACKWARD
655        # Queueing the post-backward callback is the only logic that is not
656        # per-handle in the pre-backward hook, so we can return early here if
657        # there are no handles.
658        if not handle:
659            return grad
660        handle._training_state = HandleTrainingState.BACKWARD_PRE
661
662        if handle._needs_pre_backward_unshard:
663            # If the handle has already been prefetched, then there is no need
664            # to call `_unshard()` again
665            if not handle._prefetched:
666                _unshard(
667                    state,
668                    handle,
669                    state._unshard_stream,
670                    state._pre_unshard_stream,
671                )
672            # Don't wait during trace
673            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
674                state._device_handle.current_stream().wait_stream(state._unshard_stream)
675
676        # Set this to `False` to ensure that a mistargeted prefetch does not
677        # actually unshard these handles
678        handle._needs_pre_backward_unshard = False
679        with torch.profiler.record_function(
680            "FullyShardedDataParallel._pre_backward_prefetch"
681        ):
682            _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
683        handle.prepare_gradient_for_backward()
684        handle._ran_pre_backward_hook = True
685        return grad
686
687
688@no_type_check
689@torch.no_grad()
690def _post_backward_hook(
691    state: _FSDPState,
692    handle: FlatParamHandle,
693    flat_param,
694    *unused: Any,
695):
696    """
697    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.
698
699    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
700    unsharded gradient for the local batch.
701
702    Postcondition:
703    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
704    unsharded gradient.
705    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
706    gradient (accumulating with any existing gradient).
707    """
708    _log_post_backward_hook(state, handle, logger)
709    flat_param = handle.flat_param
710    flat_param._post_backward_called = True
711    with torch.autograd.profiler.record_function(
712        "FullyShardedDataParallel._post_backward_hook"
713    ):
714        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
715        # For multiple applications of reentrant AC across submodules sharing
716        # the same `FlatParameter`, the post-backward hook may run multiple
717        # times in one backward, in which case we permit the state to already
718        # be in `BACKWARD_POST`.
719        _p_assert(
720            handle._training_state
721            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
722            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
723        )
724        handle._training_state = HandleTrainingState.BACKWARD_POST
725
726        if flat_param.grad is None:
727            return
728        if flat_param.grad.requires_grad:
729            raise RuntimeError("FSDP does not support gradients of gradients")
730
731        _post_backward_reshard(state, handle)
732        if not state._sync_gradients:
733            if handle._use_orig_params:
734                handle._use_unsharded_grad_views()
735            return
736
737        # Wait for all ops in the current stream (e.g. gradient computation) to
738        # finish before reduce-scattering the gradient
739        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
740            state._post_backward_stream.wait_stream(
741                state._device_handle.current_stream()
742            )
743
744        with state._device_handle.stream(state._post_backward_stream):
745            autograd_computed_grad = flat_param.grad.data
746            if (
747                not _low_precision_hook_enabled(state)
748                and flat_param.grad.dtype != handle._reduce_dtype
749                # If we are forcing full precision but communicating grads
750                # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient.
751                and not handle._force_full_precision
752            ):
753                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
754            if handle.uses_sharded_strategy:
755                _reduce_grad(state, handle)
756            else:
757                _reduce_grad_no_shard(state, handle)
758            # Since the unsharded gradient is produced in the computation
759            # stream and consumed in the post-backward stream, inform the
760            # caching allocator (before it goes out of scope)
761            _no_dispatch_record_stream(
762                autograd_computed_grad, state._post_backward_stream
763            )
764
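# Summary of the stream handoff in `_post_backward_hook` above: the unsharded
# gradient is produced in the computation (default) stream, reduced in
# `_post_backward_stream`, and, for HSDP, further all-reduced in
# `_all_reduce_stream`; each consumer stream waits on its producer stream, and
# `_no_dispatch_record_stream` informs the caching allocator so that the
# gradient memory is not reused before the asynchronous ops complete.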
765
766def _post_backward_reshard_only_hook(
767    state: _FSDPState,
768    handle: FlatParamHandle,
769    *unused: Any,
770) -> None:
771    with torch.profiler.record_function(
772        "FullyShardedDataParallel._post_backward_hook_reshard_only"
773    ):
774        # `_pre_backward_hook` may not get executed if the forward output does
775        # not require grad, so overwrite the IDLE training state here to enable
776        # post-backward prefetching
777        state.training_state = TrainingState.FORWARD_BACKWARD
778        handle._training_state = HandleTrainingState.BACKWARD_POST
779        _post_backward_reshard(state, handle)
780
781
782def _post_backward_reshard(
783    state: _FSDPState,
784    handle: FlatParamHandle,
785    *unused: Any,
786) -> None:
787    free_unsharded_flat_param = _should_free_in_backward(state, handle)
788    _reshard(state, handle, free_unsharded_flat_param)
789
790    # TODO: Post-backward prefetching does not support the multiple handles
791    # per module case since the post-backward hook runs per handle, not per
792    # group of handles.
793    with torch.profiler.record_function(
794        "FullyShardedDataParallel._post_backward_prefetch"
795    ):
796        _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
797
798
799@no_type_check
800def _should_free_in_backward(
801    state: _FSDPState,
802    handle: FlatParamHandle,
803) -> bool:
804    """
805    Returns whether FSDP should free the unsharded flat parameter in the
806    post-backward or not.
807    """
808    if not handle.uses_sharded_strategy:
809        return False
810    # If not syncing gradients, then we do not free for strategies that do not
811    # reshard after forward as a *heuristic* to trade off higher memory for
812    # higher throughput.
813    return (
814        state._sync_gradients
815        or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
816    )
817
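# In other words (assuming `RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES` covers the
# FULL_SHARD-style strategies): `NO_SHARD` never frees here; when synchronizing
# gradients, every sharded strategy frees; and under `no_sync()`, only the
# strategies that reshard after forward free, so SHARD_GRAD_OP-style strategies
# keep the unsharded parameter alive to trade memory for throughput.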
818
819@no_type_check
820def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
821    """
822    For sharded strategies, this runs gradient reduction, sharded gradient
823    accumulation if needed, and the post-reduction callback.
824    """
825    flat_param = handle.flat_param
826    uses_hybrid_sharded_strategy = handle._sharding_strategy in (
827        HandleShardingStrategy.HYBRID_SHARD,
828        HandleShardingStrategy._HYBRID_SHARD_ZERO2,
829    )
830    # We clear `.grad` to permit multiple backwards. This avoids a race where
831    # the second backward pass's computation proceeds ahead of the first backward
832    # pass's reduction, which is possible since the reduction is issued in a
833    # separate stream and is asynchronous; that race would result in reducing
834    # the wrong gradient.
835    unsharded_grad = flat_param.grad.data
836    flat_param.grad = None
837    padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors(
838        state, unsharded_grad
839    )
840    if state._comm_hook is None:  # default path
841        _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor)
842        pg = (
843            handle._fake_process_group
844            if handle._use_fake_reduce
845            else state.process_group
846        )
847        dist.reduce_scatter_tensor(
848            new_sharded_grad,
849            padded_unsharded_grad,
850            group=pg,
851        )
852        if uses_hybrid_sharded_strategy:
853            # Don't wait during trace
854            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
855                state._all_reduce_stream.wait_stream(state._post_backward_stream)
856            with state._device_handle.stream(state._all_reduce_stream):
857                # Since the new sharded gradient is produced in the post-
858                # backward stream and consumed in the all-reduce stream,
859                # inform the caching allocator
860                _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream)
861                dist.all_reduce(new_sharded_grad, group=state._inter_node_pg)
862                _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
863                grad_to_offload = _accumulate_sharded_grad(
864                    state, handle, new_sharded_grad
865                )
866                _post_reduce_grad_callback(state, handle, grad_to_offload)
867                return
868        _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
869    else:
870        state._comm_hook(
871            state._comm_hook_state, padded_unsharded_grad, new_sharded_grad
872        )
873        # NOTE: HSDP variants do not support a communication hook.
874    grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad)
875    _post_reduce_grad_callback(state, handle, grad_to_offload)
876
877
878@no_type_check
879def _get_reduce_scatter_tensors(
880    state: _FSDPState, unsharded_grad: torch.Tensor
881) -> Tuple[torch.Tensor, torch.Tensor]:
882    """
883    Returns the input and output tensors to reduce-scatter, respectively.
884    """
885    chunks = list(unsharded_grad.chunk(state.world_size))
886    numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel()
887    padded_unsharded_grad = (
888        F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad
889    )
890    new_sharded_grad = torch.empty_like(chunks[0])  # padded
891    return padded_unsharded_grad, new_sharded_grad
892
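# Worked example for the padding above (hypothetical sizes): with
# `state.world_size == 4` and `unsharded_grad.numel() == 10`, `chunk(4)` yields
# chunks of 3, 3, 3, and 1 elements, so `numel_to_pad == 4 * 3 - 10 == 2`; the
# padded input has 12 elements, and each rank's reduce-scatter output
# (`new_sharded_grad`) has 3 elements.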
893
894@no_type_check
895def _accumulate_sharded_grad(
896    state: _FSDPState,
897    handle: FlatParamHandle,
898    sharded_grad: torch.Tensor,
899) -> torch.Tensor:
900    """
901    Accumulates the reduce-scattered sharded gradient with any existing sharded
902    gradient if needed, returning the gradient to offload (if CPU offloading is
903    enabled).
904    """
905    flat_param = handle.flat_param
906    _cast_grad_to_param_dtype(state, sharded_grad, flat_param)
907    # Save the sharded gradient in `_saved_grad_shard` to support gradient
908    # accumulation -- for multiple backwards, the gradient reductions may
909    # happen in arbitrary order
910    accumulate_grad = hasattr(flat_param, "_saved_grad_shard")
911    if accumulate_grad:
912        _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard)
913        flat_param._saved_grad_shard += sharded_grad
914    else:
915        flat_param._saved_grad_shard = sharded_grad
916    grad_to_offload = flat_param._saved_grad_shard
917    return grad_to_offload
918
919
920@no_type_check
921def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None:
922    """
923    For ``NO_SHARD``, this runs gradient reduction (which implicitly covers any
924    gradient accumulation) and the post-reduction callback.
925    """
926    flat_param = handle.flat_param
927    if state._comm_hook is None:  # default path
928        _div_if_needed(flat_param.grad, state._gradient_predivide_factor)
929        dist.all_reduce(flat_param.grad, group=state.process_group)
930        _div_if_needed(flat_param.grad, state._gradient_postdivide_factor)
931    else:
932        state._comm_hook(state._comm_hook_state, flat_param.grad)
933    # For `NO_SHARD`, we can keep the low precision gradients by simply
934    # omitting the cast altogether
935    if not handle._keep_low_precision_grads:
936        _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
937    grad_to_offload = flat_param.grad.data
938    _post_reduce_grad_callback(state, handle, grad_to_offload)
939
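# Contrast with the sharded path above: `_reduce_grad` reduce-scatters into a
# new `_saved_grad_shard` tensor, whereas `NO_SHARD` all-reduces `.grad` in
# place, so accumulation across multiple backwards happens directly in `.grad`
# and no separate sharded gradient attribute is needed.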
940
941@no_type_check
942def _post_reduce_grad_callback(
943    state: _FSDPState,
944    handle: FlatParamHandle,
945    # Additional arguments needed for the callback logic
946    grad_to_offload: torch.Tensor,
947):
948    """
949    This callback captures any logic to run after the gradient reduction
950    finishes. Currently, this offloads the gradient to CPU if CPU offloading is
951    enabled and uses sharded gradient views if ``use_orig_params=True``.
952    """
953    _offload_grad(state, handle, grad_to_offload)
954    _post_backward_use_sharded_grad_views(handle)
955
956
957@no_type_check
958def _offload_grad(
959    state: _FSDPState,
960    handle: FlatParamHandle,
961    grad_to_offload: torch.Tensor,
962):
963    if not handle._offload_params:
964        return
965    # Offload the gradient to CPU to ensure parameters and gradients are on the
966    # same device as required by the optimizer
967    # TODO: Investigate why `NO_SHARD` breaks correctness when using
968    # `non_blocking=True` here.
969    # TODO (rohan-varma): When CPU offload and optimizer overlap,
970    # non_blocking=True won't work since the copy may have not finished before
971    # the optimizer step executes on CPU. If we want to use non_blocking=True
972    # here, we'll have to synchronize before using result on CPU.
973    non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward
974    handle.flat_param._cpu_grad.copy_(
975        grad_to_offload.detach(), non_blocking=non_blocking
976    )  # synchronized in the post-backward callback
977    # Since the gradient being offloaded may have been produced in the
978    # computation stream and is being consumed here in the post-backward
979    # stream, inform the caching allocator
980    _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream)
981
982
983@no_type_check
984def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
985    if not handle._use_orig_params:
986        return
987    # Since the handle's `FlatParameter` completed its gradient computation, we
988    # should reset the gradient noneness mask
989    handle._reset_is_grad_none()
990    # Delay using sharded gradient views until after the reduce-scatter instead
991    # of immediately after resharding
992    handle._use_sharded_grad_views()
993    if handle._has_optim_in_backward:
994        handle.prepare_gradient_for_optim()
995        for orig_param in handle.flat_param._params:
996            # Check for `None` gradient to filter parameters not in the rank
997            if orig_param.grad is not None and hasattr(
998                orig_param, "_in_backward_optimizers"
999            ):
1000                # TODO (rohan-varma): For CPU offload, this unfortunately
1001                # operates on CPU because the parameters and gradients have
1002                # already been offloaded. We should run this on GPU after
1003                # refactoring.
1004                for optim in orig_param._in_backward_optimizers:
1005                    optim.step()
1006
1007                optim.zero_grad(set_to_none=True)
1008        handle._reset_flat_param_grad_info_if_needed()
1009        if handle._offload_params:
1010            handle.flat_param._cpu_grad = None
1011
1012
1013def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
1014    if div_factor > 1:
1015        tensor.div_(div_factor)
1016
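# Sketch of the pre-/post-divide convention used with `_div_if_needed`
# (assumption based on FSDP's default gradient averaging): the gradient is
# divided by `_gradient_predivide_factor` before the reduction and by
# `_gradient_postdivide_factor` afterward, with the two factors chosen so that
# their product equals the data-parallel world size. Splitting the division
# reduces the risk of overflow/underflow when reducing in low precision.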
1017
1018@no_type_check
1019def _cast_grad_to_param_dtype(
1020    state: _FSDPState,
1021    sharded_grad: torch.Tensor,
1022    param: FlatParameter,
1023):
1024    """
1025    Casts ``sharded_grad`` back to the full parameter dtype so that the
1026    optimizer step runs with that dtype. This performs an actual cast if
1027    1. parameters were in reduced precision during the forward since then
1028    gradients would be in that reduced precision, or
1029    2. parameters were not in reduced precision but gradients were in
1030    reduced precision for communication.
1031    However, if a low precision communication hook is registered, then this
1032    dtype cast happens in the hook instead.
1033    """
1034    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
1035    if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype:
1036        low_prec_grad_data = sharded_grad.data
1037        sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
1038        # Since for `NO_SHARD`, the gradient is produced in the computation
1039        # stream and consumed here in the post-backward stream, inform the
1040        # caching allocator; for the sharded strategies, the gradient is
1041        # produced in the post-backward stream, so this `record_stream()`
1042        # should be a no-op
1043        _no_dispatch_record_stream(
1044            low_prec_grad_data, state._device_handle.current_stream()
1045        )
1046
1047
1048def _check_grad_to_accumulate(
1049    new_sharded_grad: torch.Tensor,
1050    accumulated_grad: torch.Tensor,
1051) -> None:
1052    _p_assert(
1053        accumulated_grad.shape == new_sharded_grad.shape,
1054        "Shape mismatch when accumulating gradients: "
1055        f"existing gradient shape={accumulated_grad.shape} "
1056        f"new gradient shape={new_sharded_grad.shape}",
1057    )
1058    _p_assert(
1059        accumulated_grad.device == new_sharded_grad.device,
1060        "Device mismatch when accumulating gradients: "
1061        f"existing gradient device={accumulated_grad.device} "
1062        f"new gradient device={new_sharded_grad.device}",
1063    )
1064
1065
1066@no_type_check
1067def _low_precision_hook_enabled(state: _FSDPState) -> bool:
1068    return state._comm_hook in LOW_PRECISION_HOOKS
1069
1070
1071@no_type_check
1072@torch.no_grad()
1073def _post_backward_final_callback(
1074    state: _FSDPState,
1075    module: nn.Module,
1076):
1077    """
1078    This waits for the post-backward to finish and performs some final cleanup.
1079    This runs at the end of the entire backward pass and should only be called
1080    on the root FSDP instance.
1081    """
1082    _p_assert(
1083        state._is_root,
1084        "The post-backward callback should only be called on the root FSDP instance",
1085    )
1086    root_state = state
1087
1088    if root_state._sync_gradients:
1089        current_stream = state._device_handle.current_stream()
1090        # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish
1091        # since it currently runs in the post-backward stream. That can be
1092        # pushed to the next forward if run in a different stream
1093        current_stream.wait_stream(root_state._post_backward_stream)
1094        if root_state._all_reduce_stream is not current_stream:  # uses HSDP
1095            current_stream.wait_stream(root_state._all_reduce_stream)
1096        if root_state.cpu_offload.offload_params:
1097            # Wait for non-blocking GPU -> CPU sharded gradient copies from the
1098            # post-backward hooks to finish explicitly since CPU gradients do
1099            # not automatically synchronize with the GPU
1100            state._device_handle.current_stream().synchronize()
1101    root_state._exec_order_data.next_iter()
1102
1103    for fsdp_state in state._all_fsdp_states:
1104        _catch_all_reshard(fsdp_state)
1105        _finalize_params(fsdp_state)
1106        fsdp_state.training_state = TrainingState.IDLE
1107        handle = fsdp_state._handle
1108        if handle:
1109            handle._ran_pre_backward_hook = False
1110            handle._needs_pre_backward_unshard = False
1111            handle._post_forward_index = None
1112            handle._training_state = HandleTrainingState.IDLE
1113            handle._prefetched = False
1114    # Reset for cases like one forward and multiple backwards
1115    root_state._post_backward_callback_queued = False
1116
1117
1118@no_type_check
1119def _catch_all_reshard(
1120    state: _FSDPState,
1121) -> None:
1122    """
1123    Reshards the parameters that may not have been resharded in the
1124    post-backward hook. This can happen when a module's output is used in the
1125    forward pass, meaning that its pre-backward hook runs (unsharding the
1126    parameter), but the post-backward hook does not run because the output was
1127    not used in the loss computation corresponding to this backward pass.
1128    """
1129    # Wrap with a try-except to provide a more informative traceback if an
1130    # error is raised
1131    try:
1132        if state._handle:
1133            # TODO: This already-resharded check is brittle:
1134            # https://github.com/pytorch/pytorch/issues/83956
1135            already_resharded = (
1136                state._handle.flat_param.data_ptr()
1137                == state._handle.flat_param._local_shard.data_ptr()
1138                # If FSDP skipped using sharded views, then the flat parameter
1139                # still points to the sharded data, so we need to reshard to
1140                # use sharded views
1141                and not state._handle._skipped_use_sharded_views
1142            )
1143            if already_resharded:
1144                return
1145            free_unsharded_flat_param = _should_free_in_backward(state, state._handle)
1146            _reshard(state, state._handle, free_unsharded_flat_param)
1147    except Exception as e:
1148        _p_assert(
1149            False,
1150            f"Got exception in the catch-all reshard for {state}: {str(e)}",
1151            raise_assertion_error=False,
1152        )
1153        raise e
1154
1155
1156@no_type_check
1157def _finalize_params(
1158    state: _FSDPState,
1159) -> None:
1160    """Finalizes the parameters before the next iteration."""
1161    handle = state._handle
1162    if not handle:
1163        return
1164    flat_param = handle.flat_param
1165    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
1166        if hasattr(flat_param, "_post_backward_hook_handle"):
1167            pbhs_handle = flat_param._post_backward_hook_handle
1168            pbhs_handle.remove()
1169            del flat_param._post_backward_hook_handle
1170    else:
1171        if hasattr(flat_param, "_post_backward_hook_state"):
1172            post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
1173            expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
1174            _p_assert(
1175                post_backward_hook_state_len == expected_post_backward_hook_state_len,
1176                f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
1177            )
1178            flat_param._post_backward_hook_state[-1].remove()
1179            delattr(flat_param, "_post_backward_hook_state")
1180    if flat_param.requires_grad:
1181        if not state._sync_gradients:
1182            # Preserve the gradient accumulation state if not synchronizing
1183            # gradients: `.grad` remains the unsharded gradient from prior
1184            # `no_sync()` iterations, and `_saved_grad_shard` remains the
1185            # sharded gradient from the last synchronized iteration
1186            return
1187        if not handle._has_optim_in_backward:
1188            handle.prepare_gradient_for_optim()
1189        _p_assert(
1190            hasattr(flat_param, "_post_backward_called"),
1191            "Expects `_post_backward_called` to be set on the `FlatParameter`",
1192        )
1193        flat_param._post_backward_called = False
1194
1195
1196@no_type_check
1197def _prefetch_handle(
1198    state: _FSDPState,
1199    current_handle: Optional[FlatParamHandle],
1200    prefetch_mode: _PrefetchMode,
1201) -> None:
1202    """
1203    Prefetches the next handle if needed (without synchronization). An empty
1204    current handle cannot prefetch.
1205    """
1206    if not current_handle:
1207        return
1208    handle = _get_handle_to_prefetch(state, current_handle)
1209    if not handle:
1210        return
1211    # Temporarily emulate the training state while calling `_unshard` to
1212    # ensure the correct `as_params` for `_use_unsharded_views()`
1213    prev_training_state = handle._training_state
1214    if prefetch_mode == _PrefetchMode.BACKWARD:
1215        handle._training_state = HandleTrainingState.BACKWARD_PRE
1216    elif prefetch_mode == _PrefetchMode.FORWARD:
1217        handle._training_state = HandleTrainingState.FORWARD
1218    else:
1219        raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}")
1220    # Prefetch the next set of handles without synchronizing to allow
1221    # the sync to happen as late as possible to maximize overlap
1222    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
1223    handle._training_state = prev_training_state
1224    handle._prefetched = True
1225
1226
1227@no_type_check
1228def _get_handle_to_prefetch(
1229    state: _FSDPState,
1230    current_handle: FlatParamHandle,
1231) -> Optional[FlatParamHandle]:
1232    """
1233    Returns the handle to prefetch for the next module (or ``None`` if there is
1234    nothing to prefetch), where ``current_handle`` represents the current module.
1235
1236    "Prefetching" refers to running the unshard logic early (without
1237    synchronization), and the "next" modules depend on the recorded execution
1238    order and the current training state.
1239    """
1240    training_state = _get_training_state(current_handle)
1241    valid_training_states = (
1242        HandleTrainingState.BACKWARD_PRE,
1243        HandleTrainingState.BACKWARD_POST,
1244        HandleTrainingState.FORWARD,
1245    )
1246    _p_assert(
1247        training_state in valid_training_states,
1248        f"Prefetching is only supported in {valid_training_states} but "
1249        f"currently in {training_state}",
1250    )
1251    eod = state._exec_order_data
1252    target_handle: Optional[FlatParamHandle] = None
1253    if (
1254        training_state == HandleTrainingState.BACKWARD_PRE
1255        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
1256    ) or (
1257        training_state == HandleTrainingState.BACKWARD_POST
1258        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
1259    ):
1260        target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle)
1261        if (
1262            target_handle_candidate
1263            and target_handle_candidate._needs_pre_backward_unshard
1264            and not target_handle_candidate._prefetched
1265        ):
1266            target_handle = target_handle_candidate
1267        else:
1268            target_handle = None
1269    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
1270        target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle)
1271        if (
1272            target_handle_candidate
1273            and target_handle_candidate._needs_pre_forward_unshard
1274            and not target_handle_candidate._prefetched
1275        ):
1276            target_handle = target_handle_candidate
1277        else:
1278            target_handle = None
1279
1280    return target_handle
1281
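# Configuration sketch for the prefetching consulted above (assumption:
# standard wrapper-path constructor arguments):
#
#     fsdp_model = FSDP(
#         model,
#         backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # or BACKWARD_POST
#         forward_prefetch=True,
#     )
#
# BACKWARD_PRE issues the next handle's all-gather from the current handle's
# pre-backward hook, and BACKWARD_POST issues it from the post-backward hook;
# forward prefetching follows the execution order recorded by `_exec_order_data`.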
1282
1283def _get_training_state(
1284    handle: FlatParamHandle,
1285) -> HandleTrainingState:
1286    """Returns the training state of the handles in ``handle``."""
1287    _p_assert(handle, "Expects a non-empty handle")
1288    return handle._training_state
1289
1290
1291@no_type_check
1292def _register_pre_forward_hook(
1293    state: _FSDPState,
1294    module: nn.Module,
1295) -> None:
1296    """
1297    Registers a pre-forward hook on ``module``.
1298    """
1299    for forward_handle in state._pre_forward_handles:
1300        forward_handle.remove()
1301    state._pre_forward_handles.clear()
1302    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
1303    hook = functools.partial(
1304        _pre_forward, state, module_param_handle, _pre_forward_unshard
1305    )
1306    state._pre_forward_handles.append(
1307        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
1308    )
1309
1310
1311@no_type_check
1312def _register_post_forward_hook(
1313    state: _FSDPState,
1314    module: nn.Module,
1315) -> None:
1316    """
1317    Registers a post-forward hook on ``module``. Even if the module has no
1318    handles, we should register the hook since it will register the module's
1319    pre-backward hook.
1320    """
1321    for forward_handle in state._post_forward_handles:
1322        forward_handle.remove()
1323    state._post_forward_handles.clear()
1324    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
1325    hook = functools.partial(
1326        _post_forward,
1327        state,
1328        module_param_handle,
1329        _post_forward_reshard,
1330    )
1331    state._post_forward_handles.append(module.register_forward_hook(hook))
1332
1333
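# Illustrative sketch (not part of the upstream module): the two registration
# helpers above rely on the standard ``nn.Module`` hook API. The hypothetical
# toy below shows the same pattern outside of FSDP: a kwargs-aware pre-forward
# hook prepended ahead of other hooks plus a plain post-forward hook, both
# removable via their handles as in the cleanup loops above.
def _example_forward_hooks_sketch() -> None:
    lin = nn.Linear(4, 4)

    def pre_hook(module, args, kwargs):
        # Returning ``(args, kwargs)`` lets the hook rewrite the inputs.
        return args, kwargs

    def post_hook(module, args, output):
        # Returning ``None`` keeps the original output unchanged.
        return None

    pre_handle = lin.register_forward_pre_hook(
        pre_hook, prepend=True, with_kwargs=True
    )
    post_handle = lin.register_forward_hook(post_hook)
    lin(torch.randn(2, 4))
    pre_handle.remove()
    post_handle.remove()

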
1334@no_type_check
1335def _register_root_pre_forward_hook(
1336    state: _FSDPState,
1337    module: nn.Module,
1338):
1339    """
1340    Registers the root pre-forward hook on ``module``, which should be the local
1341    FSDP root.
1342
1343    NOTE: In the current composable FSDP design, each application of
1344    ``fully_shard()`` to a module marks that module as a local FSDP root. We
1345    may remove this assumption in the future, in which case we will need to
1346    register this root pre-forward hook on any candidate module that may be
1347    the local FSDP root.
1348    """
1349    for forward_handle in state._root_pre_forward_handles:
1350        forward_handle.remove()
1351    state._root_pre_forward_handles.clear()
1352    hook = functools.partial(_root_pre_forward, state)
1353    state._root_pre_forward_handles.append(
1354        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
1355    )
1356
1357
1358@no_type_check
1359def _register_pre_backward_hooks(
1360    state: _FSDPState,
1361    module: nn.Module,
1362    outputs: Any,
1363    handle: FlatParamHandle,
1364) -> Any:
1365    """
1366    Registers pre-backward hooks on the tensors that require gradients in the
1367    forward pass outputs ``outputs``, which were computed using the
1368    ``FlatParameter`` of ``handle``.
1369
1370    Args:
1371        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
1372            Module]).
1373
1374    Returns:
1375        Forward pass outputs with pre-backward hooks registered to tensors that
1376        require gradients.
1377    """
1378    # If there is no gradient computation, then there is no need for
1379    # pre-backward logic
1380    if not torch.is_grad_enabled():
1381        return outputs
1382    if state._is_root:
1383        state._post_backward_callback_queued = False  # only defined on the root
1384
1385    if handle:
1386        handle._needs_pre_backward_unshard = False
1387        # Since this handle's `FlatParameter` participated in a forward, we
1388        # conservatively assume that it will be used in the backward
1389        handle._ran_pre_backward_hook = False
1390
1391    def _register_hook(t: torch.Tensor) -> torch.Tensor:
1392        if t.requires_grad:
1393            t.register_hook(
1394                torch.utils.hooks.unserializable_hook(
1395                    functools.partial(_pre_backward_hook, state, module, handle)
1396                )
1397            )
1398            if handle:
1399                handle._needs_pre_backward_unshard = True
1400        return t
1401
1402    return _apply_to_tensors(_register_hook, outputs)
1403
1404
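# Illustrative sketch (not part of the upstream module): the pre-backward hook
# above is attached via ``Tensor.register_hook`` on forward outputs that
# require grad, so it fires when those outputs' gradients are first computed,
# i.e. at the start of that module's backward. The hypothetical toy below
# shows the mechanism in isolation.
def _example_pre_backward_tensor_hook_sketch() -> None:
    fired = []

    def _mark_backward_started(grad: torch.Tensor) -> torch.Tensor:
        # FSDP uses this point to unshard parameters for the upcoming backward.
        fired.append(True)
        return grad

    x = torch.randn(3, requires_grad=True)
    out = (x * 2).sum()
    out.register_hook(_mark_backward_started)
    out.backward()
    assert fired, "the tensor hook should have run during backward"

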
1405def _register_post_backward_hook(
1406    state: _FSDPState,
1407    handle: Optional[FlatParamHandle],
1408) -> None:
1409    """
1410    Registers a post-backward hook on the ``FlatParameter``'s
1411    ``AccumulateGrad`` object to reshard and to reduce-scatter gradients.
1412
1413    The ``AccumulateGrad`` object represents the last function that finalizes
1414    the ``FlatParameter`` 's gradient, so it only runs after its entire
1415    gradient computation has finished.
1416
1417    We register the post-backward hook only once in the *first* forward that a
1418    ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
1419    object being preserved through multiple forwards.
1420
1421    NOTE: We prefer the *first* forward to handle the parameter mixed precision
1422    case, where there are *separate* ``AccumulateGrad`` objects across the
1423    different forwards. (Without parameter mixed precision, the
1424    ``AccumulateGrad`` objects are the same.) If we instead preferred the
1425    *last* forward, then the hook could run too early.
1426    """
1427    # If there is no gradient computation, then there is no need for
1428    # post-backward logic
1429    if not torch.is_grad_enabled():
1430        return
1431    if not handle:
1432        return
1433    flat_param = handle.flat_param
1434
1435    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
1436        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
1437        if already_registered or not flat_param.requires_grad:
1438            return
1439        hook = functools.partial(_post_backward_hook, state, handle)
1440        hook_handle = flat_param.register_post_accumulate_grad_hook(hook)
1441        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined]
1442    else:
1443        already_registered = hasattr(flat_param, "_post_backward_hook_state")
1444        if already_registered or not flat_param.requires_grad:
1445            return
1446        # Get the `AccumulateGrad` object
1447        temp_flat_param = flat_param.expand_as(flat_param)
1448        _p_assert(
1449            temp_flat_param.grad_fn is not None,
1450            "The `grad_fn` is needed to access the `AccumulateGrad` and "
1451            "register the post-backward hook",
1452        )
1453        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]  # type: ignore[union-attr]
1454        assert acc_grad is not None
1455        hook_handle = acc_grad.register_hook(
1456            functools.partial(_post_backward_hook, state, handle)
1457        )
1458        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]
1459
1460
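# Illustrative sketch (not part of the upstream module): the ``expand_as``
# trick above is one way to reach a leaf tensor's ``AccumulateGrad`` node. The
# hypothetical toy below registers a hook on that node and checks that it ran
# by the time backward finishes, after the gradient has been accumulated.
def _example_accumulate_grad_hook_sketch() -> None:
    param = torch.randn(4, requires_grad=True)
    events: List[str] = []

    # ``param`` is a leaf, so it has no ``grad_fn``; a no-op view created by
    # ``expand_as`` exposes a ``grad_fn`` whose ``next_functions`` holds the
    # leaf's ``AccumulateGrad`` node.
    acc_grad = param.expand_as(param).grad_fn.next_functions[0][0]  # type: ignore[union-attr]
    acc_grad.register_hook(lambda *unused: events.append("post-accumulate"))

    (param * 3).sum().backward()
    assert param.grad is not None and events == ["post-accumulate"]

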
1461def _register_post_backward_reshard_only_hook(
1462    state: _FSDPState,
1463    handle: Optional[FlatParamHandle],
1464    args: Tuple[Any, ...],
1465    kwargs: Dict[str, Any],
1466) -> None:
1467    """
1468    Registers post-backward hooks to reshard flat parameters that do not
1469    require gradient. We register these using multi-grad hooks on the input
1470    activations to ensure that all gradients that may depend on the parameters
1471    have been computed before resharding. (See the illustrative sketch below.)
1472    """
1473    # If there is no gradient computation, then there is no need for
1474    # post-backward logic
1475    if not torch.is_grad_enabled():
1476        return
1477    # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
1478    # where each flat parameter requires gradient
1479    inp_tensors: Optional[List[torch.Tensor]] = None
1480    if not handle:
1481        return
1482    flat_param = handle.flat_param
1483
1484    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
1485        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
1486    else:
1487        already_registered = hasattr(flat_param, "_post_backward_hook_state")
1488
1489    if already_registered or flat_param.requires_grad:
1490        return
1491    if inp_tensors is None:
1492        args_flat = pytree.arg_tree_leaves(*args, **kwargs)
1493        inp_tensors = [
1494            obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad
1495        ]
1496    assert inp_tensors is not None  # mypy
1497    hook_handle = register_multi_grad_hook(
1498        inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle)
1499    )
1500    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
1501        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined, assignment]
1502    else:
1503        flat_param._post_backward_hook_state = (hook_handle,)  # type: ignore[attr-defined, assignment]
1504
1505
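# Illustrative sketch (not part of the upstream module): the reshard-only path
# above uses ``register_multi_grad_hook`` so that its hook fires only once all
# of the watched input tensors have had their gradients computed. A
# hypothetical stand-alone use of that API:
def _example_multi_grad_hook_sketch() -> None:
    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    fired = []

    def _on_all_grads_ready(grads) -> None:
        # ``grads`` holds one (possibly ``None``) gradient per watched tensor.
        fired.append(len(grads))

    hook_handle = register_multi_grad_hook((a, b), _on_all_grads_ready)
    (a * b).sum().backward()
    assert fired == [2]
    hook_handle.remove()

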
1506@no_type_check
1507def _register_post_backward_final_callback(
1508    state: _FSDPState, module: nn.Module
1509) -> None:
1510    """
1511    Registers the post-backward final callback that runs at the end of the
1512    backward pass. This should be called from the root FSDP instance at the
1513    beginning of the pre-backward.
1514    """
1515    _p_assert(
1516        state._is_root,
1517        "Only the root FSDP instance should register the post-backward callback",
1518    )
1519    if state._post_backward_callback_queued:
1520        return
1521    _assert_in_training_states(state, [TrainingState.IDLE])
1522    # Trace does not need this callback
1523    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
1524        state._post_backward_callback_queued = True
1525        Variable._execution_engine.queue_callback(
1526            functools.partial(_post_backward_final_callback, state, module)
1527        )
1528
1529
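# Illustrative sketch (not part of the upstream module): ``queue_callback`` on
# the autograd engine (a private API, as used above) schedules a callable to
# run once the current backward pass finishes, and it must be called while
# backward is executing. The hypothetical toy below queues such a callback
# from inside a tensor hook and checks the ordering.
def _example_final_callback_sketch() -> None:
    order: List[str] = []

    def _tensor_hook(grad: torch.Tensor) -> torch.Tensor:
        # Queue the end-of-backward callback from within the backward pass,
        # mirroring how the root FSDP instance queues it in its pre-backward.
        Variable._execution_engine.queue_callback(lambda: order.append("final"))
        order.append("hook")
        return grad

    x = torch.randn(2, requires_grad=True)
    out = x.sum()
    out.register_hook(_tensor_hook)
    out.backward()
    assert order == ["hook", "final"]

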
1530def _wait_for_computation_stream(
1531    computation_stream: torch.Stream,
1532    unshard_stream: torch.Stream,
1533    pre_unshard_stream: torch.Stream,
1534):
1535    """
1536    Has the unshard and pre-unshard streams wait for the computation stream.
1537    For example, this should be called in the FSDP root's pre-forward to
1538    respect optimizer step computation.
1539    """
1540    # Tracing does not need to wait
1541    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
1542        return
1543    unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
1544    # Having the pre-all-gather stream wait for the current stream even if we
1545    # do not leverage the pre-all-gather stream is tolerable since this only
1546    # runs once per iteration
1547    pre_unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
1548
1549
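# Illustrative sketch (not part of the upstream module): a hypothetical
# CUDA-only toy showing the same ``wait_stream`` ordering used above, where a
# side stream (standing in for the unshard stream) waits on the default
# computation stream before launching work that consumes its results.
def _example_stream_wait_sketch() -> None:
    if not torch.cuda.is_available():
        return
    computation_stream = torch.cuda.current_stream()
    unshard_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # runs on the computation stream
    # Make the side stream wait so that it observes ``y`` before its own work.
    unshard_stream.wait_stream(computation_stream)
    with torch.cuda.stream(unshard_stream):
        gathered = y + 1  # stands in for all-gather-like work consuming ``y``
    # Using ``gathered`` back on the computation stream would need the reverse
    # wait; the sketch simply synchronizes the device instead.
    torch.cuda.synchronize()
    assert gathered.device.type == "cuda"

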
1550def _reset_flat_param_grad_info_if_needed(
1551    handles: List[FlatParamHandle],
1552):
1553    """
1554    Clears the original parameters' gradients if needed. This method's CPU
1555    overhead is minimal, so we may call it throughout FSDP methods, which then
1556    serve as call sites to free the gradient memory earlier.
1557    """
1558    if not isinstance(handles, list):
1559        handles = [handles]
1560    for handle in handles:
1561        if handle._use_orig_params:
1562            handle._reset_flat_param_grad_info_if_needed()
1563
1564
1565@no_type_check
1566def _get_buffers_and_dtypes_for_computation(
1567    state: _FSDPState,
1568    root_module: nn.Module,
1569) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
1570    """
1571    Returns all buffers in the module tree rooted at ``root_module`` and a
1572    corresponding list of the buffer dtypes for computation. Each buffer dtype
1573    is ``None`` if buffer mixed precision is not enabled and is otherwise the
1574    buffer's low-precision dtype.
1575    """
1576    _p_assert(state._is_root, "Expects the root to cast buffers")
1577    buffers: List[torch.Tensor] = []
1578    buffer_dtypes: List[Optional[torch.dtype]] = []
1579    visited_buffers: Set[torch.Tensor] = set()
1580    # Traverse the FSDP states bottom-up so that we prefer the owning FSDP
1581    # instance's mixed precision setting for each buffer
1582    fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
1583        root_module
1584    )
1585    for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)):
1586        for buffer_name, buffer in fsdp_module.named_buffers():
1587            if buffer in visited_buffers:
1588                continue
1589            visited_buffers.add(buffer)
1590            if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names:
1591                continue
1592            buffers.append(buffer)
1593            buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
1594    assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
1595    return buffers, buffer_dtypes
1596
1597
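# Illustrative sketch (not part of the upstream module): the buffer dtypes
# collected above come from each FSDP instance's ``MixedPrecision`` config.
# The hypothetical helper below enables buffer mixed precision; it is never
# called from this file and assumes an already-initialized process group.
def _example_buffer_mixed_precision_sketch(module: nn.Module) -> nn.Module:
    # Imported locally to avoid a circular import with the FSDP package.
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import MixedPrecision

    return FSDP(
        module,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            # Buffers are cast to this dtype for computation; leaving it as
            # ``None`` keeps buffers in their original dtype.
            buffer_dtype=torch.bfloat16,
        ),
    )

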
1598@no_type_check
1599def _get_orig_buffer_dtypes(
1600    state: _FSDPState,
1601    buffer_names: List[str],
1602) -> List[torch.dtype]:
1603    """
1604    Returns the original buffer dtypes for the given buffer names.
1605    """
1606    buffer_dtypes: List[torch.dtype] = []
1607    for buffer_name in buffer_names:
1608        _p_assert(
1609            buffer_name in state._buffer_name_to_orig_dtype,
1610            f"{buffer_name} is missing from pre-computed dict on rank "
1611            f"{state.rank}, which only has keys "
1612            f"{state._buffer_name_to_orig_dtype.keys()}",
1613        )
1614        buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name])
1615    return buffer_dtypes
1616
1617
1618def _cast_buffers_to_dtype_and_device(
1619    buffers: List[torch.Tensor],
1620    buffer_dtypes: List[Optional[torch.dtype]],
1621    device: torch.device,
1622) -> None:
1623    """
1624    Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
1625    to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
1626    corresponding buffer is only moved to ``device``.
1627    """
1628    _p_assert(
1629        buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
1630        f"Expects `buffers` and `buffer_dtypes` to have the same length if "
1631        f"`buffer_dtypes` is specified but got {len(buffers)} and "
1632        f"{len(buffer_dtypes)}",
1633    )
1634    for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
1635        if not torch.is_floating_point(buffer) or buffer_dtype is None:
1636            buffer.data = buffer.to(device=device)
1637        else:
1638            buffer.data = buffer.to(device=device, dtype=buffer_dtype)
1639
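# Illustrative sketch (not part of the upstream module): a hypothetical direct
# call showing the ``None``-means-move-only behavior documented above, plus
# the rule that non-floating-point buffers are never cast.
def _example_cast_buffers_sketch() -> None:
    float_cast = torch.zeros(4)
    float_keep = torch.zeros(4)
    int_buf = torch.zeros(4, dtype=torch.int64)
    _cast_buffers_to_dtype_and_device(
        [float_cast, float_keep, int_buf],
        [torch.float16, None, torch.float16],
        torch.device("cpu"),
    )
    assert float_cast.dtype == torch.float16  # floating point with a dtype: cast
    assert float_keep.dtype == torch.float32  # ``None`` dtype: only moved
    assert int_buf.dtype == torch.int64  # non-floating point: never cast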