xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/_init_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import itertools
4import os
5import warnings
6from typing import (
7    Any,
8    Callable,
9    Deque,
10    Dict,
11    Generator,
12    Iterable,
13    Iterator,
14    List,
15    no_type_check,
16    Optional,
17    Set,
18    Tuple,
19    TYPE_CHECKING,
20    Union,
21)
22
23import torch
24import torch.distributed as dist
25import torch.distributed.fsdp._exec_order_utils as exec_order_utils
26import torch.distributed.fsdp._traversal_utils as traversal_utils
27import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
28import torch.nn as nn
29from torch.distributed.algorithms._comm_hooks import default_hooks
30from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
31from torch.distributed.distributed_c10d import _get_default_group
32from torch.distributed.fsdp._common_utils import (
33    _FSDPDeviceHandle,
34    _FSDPState,
35    _get_module_fsdp_state,
36    _is_fsdp_flattened,
37    _named_parameters_with_duplicates,
38    clean_tensor_name,
39    TrainingState,
40)
41from torch.distributed.fsdp._flat_param import (
42    _FSDP_USE_FULL_PREC_IN_EVAL,
43    FlatParameter,
44    FlatParamHandle,
45    HandleShardingStrategy,
46)
47from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
48from torch.distributed.fsdp.api import (
49    BackwardPrefetch,
50    CPUOffload,
51    FullOptimStateDictConfig,
52    FullStateDictConfig,
53    MixedPrecision,
54    ShardingStrategy,
55    StateDictConfig,
56    StateDictType,
57)
58from torch.distributed.fsdp.wrap import _Policy
59from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
60from torch.distributed.utils import _sync_params_and_buffers
61from torch.utils._python_dispatch import is_traceable_wrapper_subclass
62
63
64if TYPE_CHECKING:
65    from torch.utils.hooks import RemovableHandle
66
67_TORCHDISTX_AVAIL = True
68try:
69    from torchdistx import deferred_init, fake  # type: ignore[import]
70except ImportError:
71    _TORCHDISTX_AVAIL = False
72
73PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
74FSDP_SYNCED = "_fsdp_synced"
75# Specification of process groups for hybrid sharding strategies.
76HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
77# Overall specification of process group.
78ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]
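
# Illustrative note (not part of the original module): for hybrid sharding
# strategies, a 2-tuple of (intra-node, inter-node) process groups matching
# ``HybridShardProcessGroupType`` may be passed; otherwise a single
# ``dist.ProcessGroup`` (or ``None`` for the default group) is expected.
# A sketch, assuming an already-initialized default process group:
#
#   intra_pg, inter_pg = _init_intra_and_inter_node_groups(
#       _get_default_group(), num_devices_per_node=8
#   )
#   hybrid_pg: HybridShardProcessGroupType = (intra_pg, inter_pg)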
79
80
81# TODO (awgu): Refactor this later
82SHARDING_STRATEGY_MAP = {
83    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
84    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
85    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
86    ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
87    ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
88}
89HYBRID_SHARDING_STRATEGIES = [
90    ShardingStrategy.HYBRID_SHARD,
91    ShardingStrategy._HYBRID_SHARD_ZERO2,
92]
93NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
94    ShardingStrategy.SHARD_GRAD_OP,
95    ShardingStrategy._HYBRID_SHARD_ZERO2,
96)
97
98
99# NOTE: Since non-self attributes cannot be type annotated, several attributes
100# on `state` are defined first as local variables before being assigned.
101
102
103@no_type_check
104def _init_process_group_state(
105    state: _FSDPState,
106    process_group: ProcessGroupType,
107    sharding_strategy: ShardingStrategy,
108    policy: Optional[_Policy],
109    device_mesh: Optional[DeviceMesh] = None,
110) -> _FSDPState:
111    if process_group is not None and device_mesh is not None:
112        raise ValueError(
113            "Cannot pass both process_group and device_mesh at the "
114            "same time. Please just pass only one of them."
115        )
116    is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
117    if is_hybrid_strategy:
118        if process_group is None and policy is None and device_mesh is None:
119            # Raise an error here: since this is manual wrapping with no process
120            # group passed in, there is no way to ensure that all wrapped FSDP
121            # instances use the same process groups.
122            raise ValueError(
123                f"Manual wrapping with {sharding_strategy} "
124                "requires explicit specification of process group or device_mesh."
125            )
126        else:
127            state = _init_process_group_state_for_hybrid_shard(
128                state, process_group, device_mesh
129            )
130    else:
131        if device_mesh:
132            state._device_mesh = device_mesh
133            state.process_group = device_mesh.get_group(mesh_dim=0)
134        else:
135            state.process_group = (
136                process_group if process_group is not None else _get_default_group()
137            )
138
139    state.rank = state.process_group.rank()
140    state.world_size = state.process_group.size()
141    data_parallel_world_size = state.world_size
142    if is_hybrid_strategy:
143        data_parallel_world_size *= state._inter_node_pg.size()
144    state._gradient_predivide_factor = (
145        default_hooks.DefaultState._get_gradient_predivide_factor(
146            data_parallel_world_size
147        )
148    )
149    state._gradient_postdivide_factor = (
150        data_parallel_world_size / state._gradient_predivide_factor
151    )
152    return state
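

# Worked example (an illustrative sketch, not part of the original module): the
# pre- and post-divide factors set above are chosen so that their product is the
# data-parallel world size, i.e. dividing gradients by the predivide factor
# before reduction and by the postdivide factor afterwards averages them over
# all data-parallel ranks while keeping intermediate values in a smaller range.
def _example_gradient_divide_factors(data_parallel_world_size: int = 16) -> float:
    # Hypothetical helper for illustration only: fetch the predivide factor the
    # same way `_init_process_group_state` does and derive the matching
    # postdivide factor (world size divided by the predivide factor).
    predivide = default_hooks.DefaultState._get_gradient_predivide_factor(
        data_parallel_world_size
    )
    return data_parallel_world_size / predivide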
153
154
155@no_type_check
156def _init_process_group_state_for_hybrid_shard(
157    state: _FSDPState,
158    process_group: ProcessGroupType,
159    device_mesh: DeviceMesh,
160) -> _FSDPState:
161    if device_mesh:
162        if _is_valid_hybrid_shard_device_mesh(device_mesh):
163            state._device_mesh = device_mesh
164            # We currently only allow _inter_node_pg to be the outermost dimension
165            # and the process_group (intra-node) to be the innermost dimension.
166            state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
167            state.process_group = device_mesh.get_group(mesh_dim=1)
168        else:
169            raise ValueError(
170                f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}"
171            )
172    elif process_group is None:
173        default_group = _get_default_group()
174        intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
175            default_group, state._device_handle.device_count()
176        )
177        # Shard across the intra-node process group
178        state.process_group = intra_node_group
179        # Save _inter_node_pg to all-reduce gradients across nodes
180        state._inter_node_pg = inter_node_group
181    else:
182        # Check type and assign state.process_group and state._inter_node_pg.
183        if _is_valid_hybrid_shard_pg_type(process_group):
184            # Assume that the user passed the tuple in as
185            # (intra-node group, inter-node group), as documented.
186            state.process_group, state._inter_node_pg = process_group
187        else:
188            raise ValueError(
189                "Expected process_group to be passed in as either None or "
190                f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
191            )
192    # Create state for allreduce
193    state._inter_node_state = _get_default_comm_hook_state(
194        process_group=state._inter_node_pg,
195    )
196    return state
197
198
199@no_type_check
200def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
201    return (
202        isinstance(process_group, tuple)
203        and len(process_group) == 2
204        and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
205    )
206
207
208@no_type_check
209def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
210    return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2
211
212
213@no_type_check
214def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
215    """
216    Return a process group across the current node.
217
218    For example, given each row is a distinct node:
219    0  1  2  3  4  5  6  7
220    8  9 10 11 12 13 14 15
221    This API would return an intra-node subgroup across
222    [0, 1, ..., 7] or [8, 9, ..., 15] depending on the process's rank.
223    For example, rank 3 would get [0, 1, ..., 7].
224    """
225    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
226    return intra_node_subgroup
227
228
229@no_type_check
230def _init_inter_node_process_group(
231    global_process_group: dist.ProcessGroup,
232    num_devices_per_node: int,
233) -> dist.ProcessGroup:
234    """
235    Return an inter-node process group where each contained rank has the same local rank.
236
237    For example, given each row is a distinct node:
238    0  1  2  3  4  5  6  7
239    8  9 10 11 12 13 14 15
240    This API would return inter-node process group [0, 8], [1, 9], [2, 10], and so forth
241    depending on the process's rank. For example, rank 1 would get [1, 9], rank 5
242    would get [5, 13].
243    """
244    # the inter-node pg that is returned
245    inter_node_pg = None
246    sharding_backend = dist.get_backend(global_process_group)
247    world_size = dist.get_world_size(global_process_group)
248    # Assuming fully homogeneous setup
249    num_nodes = world_size // num_devices_per_node
250    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
251    for local_rank in range(num_devices_per_node):
252        ranks_for_inter_group = [
253            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
254        ]
255        # every rank always needs to call dist.new_group
256        grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
257        if local_rank == my_local_rank:
258            inter_node_pg = grp
259
260    assert (
261        inter_node_pg is not None
262    ), f"{my_local_rank} expected to assign inter-node pg, but did not"
263    return inter_node_pg
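

# Illustrative sketch (not part of the original module): the rank arithmetic
# used above, reproduced without any collective calls. With 16 ranks and 8
# devices per node it yields one inter-node group per local rank:
# [0, 8], [1, 9], ..., [7, 15]; e.g. _example_inter_node_rank_groups(16, 8)[5]
# evaluates to [5, 13].
def _example_inter_node_rank_groups(
    world_size: int, num_devices_per_node: int
) -> List[List[int]]:
    # Hypothetical helper for illustration only.
    num_nodes = world_size // num_devices_per_node
    return [
        [local_rank + i * num_devices_per_node for i in range(num_nodes)]
        for local_rank in range(num_devices_per_node)
    ]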
264
265
266def _init_intra_and_inter_node_groups(
267    global_process_group: dist.ProcessGroup,
268    num_devices_per_node: int,
269) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
270    """
271    Initialize intra and inter-node process groups and return the ones corresponding to this process's rank.
272
273    This function can be used to initialize process groups for ``HYBRID_SHARD`` or
274    ``_HYBRID_SHARD_ZERO2`` in FSDP.
275    This function assumes each node has an equal number of CUDA-enabled devices.
276    Returns:
277        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group.
278    """
279    return (
280        _init_intra_node_process_group(num_devices_per_node),
281        _init_inter_node_process_group(global_process_group, num_devices_per_node),
282    )
283
284
285@no_type_check
286def _init_ignored_module_states(
287    state: _FSDPState,
288    module: nn.Module,
289    ignored_modules: Optional[Iterable[torch.nn.Module]],
290    ignored_states: Union[
291        Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
292    ] = None,
293) -> _FSDPState:
294    if ignored_modules is not None and ignored_states is not None:
295        raise ValueError(
296            "Cannot pass both ignored_modules and ignored_states at the "
297            "same time. Please just pass ignored_states."
298        )
299    ignored_parameters = None
300    passed_as_ignored_states = ignored_states is not None
301    if passed_as_ignored_states:
302        ignored_states_list = list(ignored_states)
303        _check_ignored_states(ignored_states_list, True)
304    else:
305        ignored_states_list = []
306        _check_ignored_states(
307            list(ignored_modules) if ignored_modules is not None else [], False
308        )
309    if len(ignored_states_list) > 0:
310        if isinstance(ignored_states_list[0], nn.Parameter):
311            ignored_parameters = ignored_states_list
312        else:
313            ignored_modules = ignored_states_list
314    state._ignored_modules = _get_ignored_modules(module, ignored_modules)
315    state._ignored_params = _get_ignored_params(
316        module,
317        state._ignored_modules,
318        ignored_parameters,
319    )
320    state._ignored_buffer_names = _get_ignored_buffer_names(
321        module,
322        state._ignored_modules,
323    )
324    # TODO: FSDP's contract for buffers is not well-defined. They are
325    # implicitly ignored for most functionality since they are not sharded;
326    # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed
327    # precision). We should formalize this contract and decide if we need to
328    # compute and store `_ignored_buffers`.
329    return state
330
331
332def _check_ignored_states(
333    ignored_states: List[Any], passed_as_ignored_states: bool
334) -> None:
335    """
336    Check that the ignored states are uniformly parameters or uniformly modules.
337
338    We may remove this check in the future if we permit mixing.
339    """
340    if len(ignored_states) == 0:
341        return
342    if passed_as_ignored_states:
343        all_params = all(isinstance(state, nn.Parameter) for state in ignored_states)
344        all_modules = all(isinstance(state, nn.Module) for state in ignored_states)
345        if not all_params and not all_modules:
346            # Sort for consistent ordering for unit test regex matching
347            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
348            raise ValueError(
349                "ignored_states expects all nn.Parameter or all nn.Module list "
350                f"elements but got types {sorted_types}"
351            )
352    else:
353        if not all(isinstance(state, nn.Module) for state in ignored_states):
354            sorted_types = sorted({type(state) for state in ignored_states}, key=repr)
355            raise ValueError(
356                "ignored_modules expects nn.Module list elements but got "
357                f"types {sorted_types}"
358            )
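

# Usage sketch (illustrative only, not part of the original module): a
# homogeneous list of modules or of parameters passes the check above, while a
# mixed list raises ``ValueError``.
#
#   lin = nn.Linear(4, 4)
#   _check_ignored_states([lin], passed_as_ignored_states=True)   # OK: all modules
#   _check_ignored_states(list(lin.parameters()), True)           # OK: all parameters
#   _check_ignored_states([lin, lin.weight], True)                # raises ValueError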
359
360
361@no_type_check
362def _init_device_handle(
363    state: _FSDPState,
364    module: nn.Module,
365    ignored_params: Set[nn.Parameter],
366    device_id: Optional[Union[int, torch.device]],
367) -> _FSDPState:
368    """
369    Determine the device handle used for initializing FSDP.
370
371    If a device is specified by ``device_id``, then the returned device handle
372    corresponds to that device's type. Otherwise, if the module is already on a
373    non-CPU device, then the device type is that non-CPU device type. If the
374    module is on CPU or meta, then the device type is the current accelerator
375    device. See :ref:`Accelerators<accelerators>` for details.
376
377
378    This method should be called after the ignored parameters have been
379    determined, as the device handle may be needed for other initialization.
380    """
381    determined_device = None
382    if device_id is not None:
383        determined_device = (
384            device_id
385            if isinstance(device_id, torch.device)
386            else torch.device(device_id)
387        )
388    if determined_device is None:
389        for param in _get_orig_params(module, ignored_params):
390            if param.device.type in {"cpu", "meta"}:
391                continue
392            if determined_device is None:
393                determined_device = param.device
394            else:
395                if param.device.type != determined_device.type:
396                    raise RuntimeError(
397                        f"FSDP does not support modules with different device types "
398                        f"but got params on {determined_device.type} and {param.device.type}"
399                    )
400        determined_device = determined_device or torch._C._get_accelerator()
401        if determined_device.type == "cpu":
402            raise RuntimeError(
403                "FSDP needs a non-CPU accelerator device, but no accelerator device is detected."
404            )
405
406    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
407    return state
408
409
410@no_type_check
411def _init_buffer_state(
412    state: _FSDPState,
413    module: nn.Module,
414) -> _FSDPState:
415    state._buffer_names = _get_buffer_names(module)
416    # Save a mapping from clean fully-qualified buffer name (starting from
417    # `module`) to its original dtype for restoring that dtype during model
418    # checkpointing when buffer mixed precision is enabled. The names should
419    # be clean since the casting happens in a `summon_full_params()` context.
420    _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
421    for buffer_name, buffer in module.named_buffers():
422        buffer_name = clean_tensor_name(buffer_name)
423        _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
424    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
425    return state
426
427
428@no_type_check
429def _init_core_state(
430    state: _FSDPState,
431    sharding_strategy: Optional[ShardingStrategy],
432    mixed_precision: Optional[MixedPrecision],
433    cpu_offload: Optional[CPUOffload],
434    limit_all_gathers: bool,
435    use_orig_params: bool,
436    backward_prefetch_limit: int,
437    forward_prefetch_limit: int,
438) -> _FSDPState:
439    # We clamp the strategy to `NO_SHARD` for world size of 1 since they are
440    # currently functionally equivalent. This may change if/when we integrate
441    # FSDP with MoE.
442    if state.world_size == 1:
443        if sharding_strategy != ShardingStrategy.NO_SHARD:
444            warnings.warn(
445                "FSDP is switching to use `NO_SHARD` instead of "
446                f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
447                "the world size is 1."
448            )
449        sharding_strategy = ShardingStrategy.NO_SHARD
450    elif sharding_strategy == ShardingStrategy.NO_SHARD:
451        warnings.warn(
452            "The `NO_SHARD` sharding strategy is deprecated. If having issues, "
453            "please use `DistributedDataParallel` instead.",
454            FutureWarning,
455            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
456            # level 3 is from the true caller
457            stacklevel=3,
458        )
459    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
460    state.mixed_precision = mixed_precision or MixedPrecision()
461    if mixed_precision is not None:
462        torch._C._log_api_usage_once(
463            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
464        )
465    state._use_full_prec_in_eval = (
466        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
467    )
468    state.cpu_offload = cpu_offload or CPUOffload()
469    state.limit_all_gathers = limit_all_gathers
470    state._use_orig_params = use_orig_params
471    state.training_state = TrainingState.IDLE
472    state._is_root = None
473    state._free_event_queue = _FreeEventQueue()
474    state._debug_level = dist.get_debug_level()
475    state._exec_order_data = exec_order_utils._ExecOrderData(
476        state._debug_level,
477        backward_prefetch_limit,
478        forward_prefetch_limit,
479    )
480    state._unshard_event = None
481    # Mapping from fully sharded module to the handles it is responsible to
482    # unshard and reshard (see [Note: Fully Sharded Module])
483    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = {}
484    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
485    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
486    # handles in `state._handle`
487    _handle: Optional[FlatParamHandle] = None
488    state._handle = _handle
489    params: List[FlatParameter] = []
490    state.params = params
491    return state
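

# Illustrative note (not part of the original module): `_use_full_prec_in_eval`
# above is driven by an environment variable; setting it to "1" before
# constructing FSDP is intended to keep full precision during `eval()` even
# when a `MixedPrecision` config is in use. A sketch:
#
#   os.environ[_FSDP_USE_FULL_PREC_IN_EVAL] = "1"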
492
493
494@no_type_check
495def _init_runtime_state(
496    state: _FSDPState,
497) -> _FSDPState:
498    _root_pre_forward_handles: List[RemovableHandle] = []
499    state._root_pre_forward_handles = _root_pre_forward_handles
500    _pre_forward_handles: List[RemovableHandle] = []
501    state._pre_forward_handles = _pre_forward_handles
502    _post_forward_handles: List[RemovableHandle] = []
503    state._post_forward_handles = _post_forward_handles
504    state._sync_gradients = True
505    state._comm_hook = None
506    state._comm_hook_state = None
507    # Used to prevent running the pre-backward hook multiple times
508    return state
509
510
511@no_type_check
512def _init_prefetching_state(
513    state: _FSDPState,
514    backward_prefetch: BackwardPrefetch,
515    forward_prefetch: bool,
516) -> _FSDPState:
517    state.backward_prefetch = backward_prefetch
518    state.forward_prefetch = forward_prefetch
519    # The data structures use tuples of handles to generalize over the case
520    # where a module's forward involves multiple handles.
521    return state
522
523
524@no_type_check
525def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
526    # TODO: we need to add additional check once we support FSDP + PiPPy.
527    # This check is currently sufficient, since we only support FSDP + TP.
528    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
529    # If the root mesh is not the same as the device_mesh, then the device_mesh
530    # was sliced out from the root mesh, so the DTensor extension is needed.
531    if device_mesh and root_mesh != state._device_mesh:
532        state._fsdp_extension = DTensorExtensions(state._device_handle)
533    else:
534        # We need to explicitly set _fsdp_extension to None.
535        # Otherwise, we will run into an infinite recursion when getting the attribute.
536        state._fsdp_extension = None
537    return state
538
539
540@no_type_check
541def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
542    state._state_dict_type = StateDictType.FULL_STATE_DICT
543    state_dict_config: StateDictConfig = FullStateDictConfig()
544    state._optim_state_dict_config = FullOptimStateDictConfig()
545    state._state_dict_config = state_dict_config
546    unshard_params_ctx: Dict[nn.Module, Generator] = {}
547    state._unshard_params_ctx = unshard_params_ctx
548
549    return state
550
551
552def _verify_managed_params(module: nn.Module, params: List[nn.Parameter]) -> None:
553    """
554    Verify that the parameters are accepted by FSDP. The only restriction for
555    now is that a parameter cannot be a scalar tensor (``param.shape == []``).
556    """
557    for param in params:
558        if len(param.shape) == 0:
559            param_name = ""
560            for name, param_ in module.named_parameters():
561                if param is param_:
562                    param_name = name
563                    break
564            assert param_name
565            raise ValueError(
566                "FSDP doesn't support scalar parameters. "
567                f"Change {param_name} to a 1D tensor with numel equal to 1."
568            )
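

# Usage sketch (illustrative only, not part of the original module): FSDP
# rejects 0-dim ("scalar") parameters, so reshape them to 1-D tensors with a
# single element before wrapping.
#
#   bad = nn.Parameter(torch.tensor(0.5))     # shape == torch.Size([]) -> rejected
#   good = nn.Parameter(torch.tensor([0.5]))  # shape == torch.Size([1]) -> accepted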
569
570
571@no_type_check
572def _init_param_handle_from_module(
573    state: _FSDPState,
574    fully_sharded_module: nn.Module,
575    device_id: Optional[Union[int, torch.device]],
576    param_init_fn: Optional[Callable[[nn.Module], None]],
577    sync_module_states: bool,
578) -> _FSDPState:
579    """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
580    _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
581    device_from_device_id = _get_device_from_device_id(
582        device_id, state.rank, state._device_handle
583    )
584    is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
585        fully_sharded_module, state._ignored_params, state._ignored_modules
586    )
587    # Materialize the module if needed
588    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
589        _materialize_with_param_init_fn(
590            fully_sharded_module, param_init_fn, state._ignored_modules
591        )
592    elif is_meta_module:
593        _materialize_meta_module(
594            fully_sharded_module,
595            device_id,
596            state._ignored_modules,
597            state._device_handle,
598        )
599    elif is_torchdistX_deferred_init:
600        deferred_init.materialize_module(
601            fully_sharded_module,
602            check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
603            and submodule not in state._ignored_modules,
604        )
605
606    ignored_buffers = {
607        buffer
608        for ignored_module in state._ignored_modules
609        for buffer in ignored_module.buffers()
610    }
611
612    _move_module_to_device(
613        fully_sharded_module,
614        state._ignored_params,
615        ignored_buffers,
616        device_from_device_id,
617    )
618    state.compute_device = _get_compute_device(
619        fully_sharded_module,
620        state._ignored_params,
621        device_from_device_id,
622        state.rank,
623        state._device_handle,
624    )
625
626    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
627    _verify_managed_params(fully_sharded_module, managed_params)
628    if sync_module_states:
629        _sync_module_params_and_buffers(
630            fully_sharded_module, managed_params, state.process_group
631        )
632        if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
633            _sync_module_params_and_buffers(
634                fully_sharded_module, managed_params, state._inter_node_pg
635            )
636    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
637    return state
638
639
640@no_type_check
641def _init_param_handle_from_params(
642    state: _FSDPState,
643    params: List[nn.Parameter],
644    fully_sharded_module: nn.Module,
645):
646    if len(params) == 0:
647        return
648    handle = FlatParamHandle(
649        params,
650        fully_sharded_module,
651        state.compute_device,
652        SHARDING_STRATEGY_MAP[state.sharding_strategy],
653        state.cpu_offload.offload_params,
654        state.mixed_precision.param_dtype,
655        state.mixed_precision.reduce_dtype,
656        state.mixed_precision.keep_low_precision_grads,
657        state.process_group,
658        state._use_orig_params,
659        fsdp_extension=state._fsdp_extension,
660    )
661    handle.shard()
662    assert not state._handle
663    state.params.append(handle.flat_param)
664    state._handle = handle
665    state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
666    cpu_device = torch.device("cpu")
667    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
668        handle.flat_param_to(cpu_device)
669
670
671def _get_ignored_modules(
672    root_module: nn.Module,
673    _ignored_modules: Optional[Iterable[torch.nn.Module]],
674) -> Set[nn.Module]:
675    """
676    Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.
677
678    Return the modules contained in the subtrees of those ignored modules as a
679    :class:`set`. Nested FSDP instances are excluded, but their already-computed
680    ignored modules are included.
681
682    ``_ignored_modules`` represents the argument passed by the user to FSDP.
683    """
684    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
685    try:
686        ignored_root_modules = (
687            set(_ignored_modules) if _ignored_modules is not None else set()
688        )
689    except TypeError as e:
690        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
691    for module in ignored_root_modules:
692        if not isinstance(module, torch.nn.Module):
693            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
694        if _get_module_fsdp_state(module):
695            # TODO: We may relax this by taking the FSDP instance's wrapped
696            # module to provide more flexibility to the user.
697            raise ValueError("`ignored_modules` should not include FSDP modules")
698    # Treat modules that cannot compose with `fully_shard` as ignored modules,
699    # meaning that their subtrees are ignored
700    for module in root_module.modules():
701        if not traversal_utils._composable(module):
702            ignored_root_modules.add(module)
703    # NOTE: Even if `ignored_root_modules` is empty, do not return early so
704    # that this FSDP instance can get any ignored modules from its children.
705
706    # Include child modules and exclude nested FSDP modules themselves
707    ignored_modules = {
708        child
709        for module in ignored_root_modules
710        for child in module.modules()
711        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
712    }
713    if root_module in ignored_modules:
714        warnings.warn(
715            "Trying to ignore the top-level module passed into the FSDP "
716            "constructor itself will result in all parameters being "
717            f"ignored and is not well-supported: {root_module}"
718        )
719    # Include nested FSDP modules' ignored modules
720    for submodule in root_module.modules():
721        optional_fsdp_state = _get_module_fsdp_state(submodule)
722        if optional_fsdp_state is not None:
723            assert hasattr(optional_fsdp_state, "_ignored_modules")
724            ignored_modules.update(optional_fsdp_state._ignored_modules)
725    return ignored_modules
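

# Usage sketch (illustrative only, not part of the original module): ignoring a
# submodule ignores its entire subtree.
#
#   model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
#   ignored = _get_ignored_modules(model, [model[1]])
#   # `ignored` now contains model[1] plus every module inside it.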
726
727
728def _get_ignored_params(
729    root_module: torch.nn.Module,
730    ignored_modules: Set[torch.nn.Module],
731    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
732) -> Set[torch.nn.Parameter]:
733    """
734    Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.
735
736    :class:`FlatParameter` s are excluded from the result.
737    """
738    all_ignored_params: Set[torch.nn.Parameter] = set()
739
740    params_in_ignored_modules = {
741        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
742    }
743
744    all_ignored_params.update(params_in_ignored_modules)
745
746    if ignored_parameters is not None:
747        params_in_ignored_parameters = {
748            p for p in ignored_parameters if not _is_fsdp_flattened(p)
749        }
750        all_ignored_params.update(params_in_ignored_parameters)
751
752    # Always include nested FSDP modules' ignored parameters
753    for submodule in root_module.modules():
754        optional_fsdp_state = _get_module_fsdp_state(submodule)
755        if optional_fsdp_state is not None:
756            assert hasattr(optional_fsdp_state, "_ignored_params")
757            all_ignored_params.update(optional_fsdp_state._ignored_params)
758
759    return all_ignored_params
760
761
762def _get_ignored_buffer_names(
763    root_module: torch.nn.Module,
764    ignored_modules: Set[torch.nn.Module],
765) -> Set[str]:
766    """Return the cleaned buffer FQNs in ``ignored_modules``."""
767    all_ignored_buffer_names: Set[str] = set()
768
769    buffers_in_ignored_modules = {
770        buffer for m in ignored_modules for buffer in m.buffers()
771    }
772
773    all_ignored_buffer_names.update(
774        {
775            clean_tensor_name(buffer_name)
776            for buffer_name, buffer in root_module.named_buffers()
777            if buffer in buffers_in_ignored_modules
778        }
779    )
780
781    # Always include nested FSDP modules' ignored buffer names
782    for submodule in root_module.modules():
783        optional_fsdp_state = _get_module_fsdp_state(submodule)
784        if optional_fsdp_state is not None:
785            assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
786            all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)
787
788    return all_ignored_buffer_names
789
790
791def _get_buffer_names(root_module: nn.Module) -> Set[str]:
792    """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a :class:`set`."""
793    return {
794        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
795    }
796
797
798def _check_single_device_module(
799    module: nn.Module,
800    ignored_params: Set[nn.Parameter],
801    device_id: Optional[Union[int, torch.device]],
802) -> None:
803    """
804    Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``.
805
806    Thus, after this method, the
807    module must be either fully on the CPU or fully on a non-CPU device.
808    """
809    devices = {param.device for param in _get_orig_params(module, ignored_params)}
810    # We allow module to be partially on CPU and partially on GPU if device_id is not
811    # None, since the device_id arg will result in the CPU portion being moved to
812    # GPU. This is useful in cases where part of the module may be parallelized
813    # by another algorithm and may already be on GPU. We'd like to enforce device_id
814    # to not be None, otherwise we'd flatten parameters in a mixed module which is
815    # not supported.
816    if len(devices) == 2 and torch.device("cpu") in devices:
817        if device_id is None:
818            raise RuntimeError(
819                "To support a module with both CPU and GPU params, "
820                "please pass in device_id argument."
821            )
822    elif len(devices) > 1:
823        raise RuntimeError(
824            f"FSDP only supports single device modules but got params on {devices}"
825        )
826
827
828def _get_device_from_device_id(
829    device_id: Optional[Union[int, torch.device]],
830    rank: int,
831    device_handle: _FSDPDeviceHandle,
832) -> Optional[torch.device]:
833    """
834    Return a ``torch.device`` for the specified ``device_id``.
835
836    Processes ``device_id`` and returns either the corresponding device or
837    ``None`` if ``device_id`` is ``None``.
838    """
839    if device_id is None:
840        return None
841    device = (
842        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
843    )
844    if device.type != "cpu" and device.index is None:
845        warnings.warn(
846            f"FSDP got the argument `device_id` {device_id} on rank "
847            f"{rank}, which does not have an explicit index. "
848            f"FSDP will use the current device {device_handle.current_device()}. "
849            f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` "
850            "before FSDP initialization or pass in the explicit device "
851            "index as the `device_id` argument."
852        )
853        device = torch.device(device_handle.current_device())
854    return device
855
856
857def _need_to_materialize_module(
858    module: nn.Module,
859    ignored_params: Set[nn.Parameter],
860    ignored_modules: Set[nn.Module],
861) -> Tuple[bool, bool]:
862    """
863    Return whether ``module`` has parameters on meta device and whether ``module`` is using torchdistX deferred initialization.
864
865    At most one of the returned bools can
866    be ``True``. If either is ``True``, then ``module`` needs to be
867    materialized.
868    """
869    managed_params = list(_get_orig_params(module, ignored_params))
870    is_meta_module = any(param.is_meta for param in managed_params)
871    # TODO: We need to establish a contract for FSDP and buffers. For now, we
872    # skip checking for meta buffers from ignored modules. We should consider
873    # refactoring the initialization holistically to avoid so many traversals.
874    for submodule in module.modules():
875        if submodule in ignored_modules:
876            continue
877        for buf in submodule.buffers(recurse=False):
878            is_meta_module |= buf.is_meta
879    is_torchdistX_deferred_init = (
880        not is_meta_module
881        and _TORCHDISTX_AVAIL
882        and any(fake.is_fake(param) for param in managed_params)
883    )
884    return is_meta_module, is_torchdistX_deferred_init
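

# Usage sketch (illustrative only, not part of the original module): a module
# constructed under the meta device is reported as needing materialization.
#
#   with torch.device("meta"):
#       m = nn.Linear(8, 8)
#   _need_to_materialize_module(m, set(), set())  # -> (True, False)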
885
886
887def _materialize_with_param_init_fn(
888    root_module: nn.Module,
889    param_init_fn: Callable[[nn.Module], None],
890    ignored_modules: Set[nn.Module],
891) -> None:
892    if not callable(param_init_fn):
893        raise ValueError(
894            f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
895        )
896    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
897    for module in modules_to_materialize:
898        param_init_fn(module)
899
900
901def _materialize_meta_module(
902    root_module: nn.Module,
903    device_from_device_id: Optional[torch.device],
904    ignored_modules: Set[nn.Module],
905    device_handle: _FSDPDeviceHandle,
906):
907    # Run default meta device initialization
908    materialization_device = device_from_device_id or torch.device(
909        device_handle.current_device()
910    )
911    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
912    module = None
913    try:
914        # Assume that each module's `reset_parameters()` only initializes its
915        # own parameters and not those of its children
916        with torch.no_grad():
917            for module in modules_to_materialize:
918                # As a contract to the user, only call `reset_parameters()` if
919                # the module has directly managed parameters/buffers
920                module_state_iter = itertools.chain(
921                    module.parameters(recurse=False), module.buffers(recurse=False)
922                )
923                has_module_states = len(list(module_state_iter)) > 0
924                if has_module_states:
925                    module.to_empty(device=materialization_device, recurse=False)
926                    module.reset_parameters()  # type: ignore[operator]
927    except BaseException as e:
928        warnings.warn(
929            "Unable to call `reset_parameters()` for module on meta "
930            f"device with error {str(e)}. Please ensure that your module of "
931            f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
932        )
933        raise e
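

# Usage sketch (illustrative only, not part of the original module): the
# per-module contract assumed above. A module that directly owns parameters or
# buffers should implement `reset_parameters()` so that it can be materialized
# from the meta device via `to_empty()` followed by `reset_parameters()`.
#
#   with torch.device("meta"):
#       layer = nn.Linear(16, 16)   # nn.Linear defines reset_parameters()
#   layer.to_empty(device=torch.device("cpu"), recurse=False)
#   with torch.no_grad():
#       layer.reset_parameters()    # now holds real, initialized values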
934
935
936def _get_modules_to_materialize(
937    root_module: nn.Module, ignored_modules: Set[nn.Module]
938) -> List[nn.Module]:
939    # Run BFS to collect the modules to materialize via `reset_parameters()`,
940    # stopping at any module with FSDP already applied or at ignored modules.
941    modules_to_materialize: List[nn.Module] = []
942    queue = collections.deque([root_module])
943    visited_modules: Set[nn.Module] = {root_module}
944    while queue:
945        module = queue.popleft()
946        modules_to_materialize.append(module)
947        for child_module in module.children():
948            if (
949                child_module not in visited_modules
950                and _get_module_fsdp_state(child_module) is None
951                and child_module not in ignored_modules
952            ):
953                visited_modules.add(child_module)
954                queue.append(child_module)
955    return modules_to_materialize
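

# Usage sketch (illustrative only, not part of the original module): the BFS
# above stops at ignored modules (and at modules that already have FSDP
# applied).
#
#   model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
#   _get_modules_to_materialize(model, {model[1]})  # -> [model, model[0]]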
956
957
958def _move_module_to_device(
959    module: nn.Module,
960    ignored_params: Set[nn.Parameter],
961    ignored_buffers: Set[torch.Tensor],
962    device_from_device_id: Optional[torch.device],
963) -> None:
964    """
965    Move ``module`` depending on ``device_from_device_id`` and its current device.
966
967    Parameters in ``ignored_params`` and buffers in ``ignored_buffers`` are not moved.
968
969    - If ``device_from_device_id`` is not ``None``, then this moves
970    ``module`` to the device.
971    - If ``device_from_device_id`` is ``None``, then this does not move
972    ``module`` but warns the user if it is on CPU.
973
974    Precondition: ``_check_single_device_module()``.
975    """
976    cpu_device = torch.device("cpu")
977    if device_from_device_id is not None:
978        # BFS from `module` without traversing any nested FSDP instances to
979        # collect the parameters/buffers that have not yet been managed
980        queue: Deque[nn.Module] = collections.deque()
981        queue.append(module)
982        params: List[nn.Parameter] = []
983        buffers: List[torch.Tensor] = []
984        while queue:
985            curr_module = queue.popleft()
986            # NOTE: We include a check to only move parameters/buffers that are
987            # on CPU device. If they are on a CUDA device different from the
988            # one specified by `device_id`, then this does NOT move them. This
989            # is so that we can raise an error in `_get_compute_device()`.
990            params.extend(
991                param
992                for param in curr_module.parameters(recurse=False)
993                if param.device == cpu_device
994            )
995            buffers.extend(
996                buffer
997                for buffer in curr_module.buffers(recurse=False)
998                if buffer.device == cpu_device
999            )
1000            for submodule in curr_module.children():
1001                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
1002                    queue.append(submodule)
1003        params_to_move = [p for p in params if p not in ignored_params]
1004        bufs_to_move = [p for p in buffers if p not in ignored_buffers]
1005        _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
1006        return
1007    param = next(_get_orig_params(module, ignored_params), None)
1008    if param is not None and param.device == cpu_device:
1009        _warn_cpu_init()
1010
1011
1012def _move_states_to_device(
1013    params: List[nn.Parameter],
1014    buffers: List[torch.Tensor],
1015    device_from_device_id: Optional[torch.device],
1016) -> None:
1017    """
1018    Move states to the specified device.
1019
1020    Precondition: ``_check_single_device_module()`` and module's parameters and
1021    buffers have been materialized if needed.
1022    """
1023    if len(params) == 0 and len(buffers) == 0:
1024        return
1025    if len(params) > 0:
1026        current_device = params[0].device
1027    elif len(buffers) > 0:
1028        current_device = buffers[0].device
1029    cpu_device = torch.device("cpu")
1030    if device_from_device_id is not None:
1031        # Move the parameters and buffers like the `.data` code path in
1032        # `nn.Module._apply()`, which underlies `nn.Module.to()`
1033        for param in params:
1034            with torch.no_grad():
1035                param.data = param.to(device_from_device_id)
1036                if param.grad is not None:
1037                    param.grad.data = param.grad.to(device_from_device_id)
1038        for buffer in buffers:
1039            buffer.data = buffer.to(device_from_device_id)
1040    elif current_device == cpu_device:  # type: ignore[possibly-undefined]
1041        _warn_cpu_init()
1042
1043
1044def _warn_cpu_init():
1045    warnings.warn(
1046        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
1047        "initialization run on CPU, which may be slower than on GPU. We "
1048        "recommend passing in the `device_id` argument for FSDP to move "
1049        "`module` to GPU for the sharding initialization. `module` must also "
1050        "be on GPU device to work with the `sync_module_states=True` flag "
1051        "since that requires GPU communication."
1052    )
1053
1054
1055def _get_compute_device(
1056    module: nn.Module,
1057    ignored_params: Set[nn.Parameter],
1058    device_from_device_id: Optional[torch.device],
1059    rank: int,
1060    device_handle: _FSDPDeviceHandle,
1061) -> torch.device:
1062    """
1063    Determine and return this FSDP instance's compute device.
1064
1065    If the module is already on a non-CPU device, then the compute device is that non-CPU
1066    device. If the module is on CPU, then the compute device is the current
1067    device.
1068
1069    Since this method should be called after materializing the module, any
1070    non-CPU device should not be meta device. For now, the compute device is
1071    always a CUDA or CUDA-like device with its explicit index.
1072
1073    Precondition: ``_check_single_device_module()`` and
1074    ``_move_module_to_device()``.
1075    """
1076    param = next(_get_orig_params(module, ignored_params), None)
1077    if param is not None and param.device.type != "cpu":
1078        compute_device = param.device  # Determined by model param placement
1079    else:
1080        compute_device = torch.device(device_handle.current_device())
1081    if device_from_device_id is not None and compute_device != device_from_device_id:
1082        raise ValueError(
1083            f"Inconsistent compute device and `device_id` on rank {rank}: "
1084            f"{compute_device} vs {device_from_device_id}"
1085        )
1086    return compute_device
1087
1088
1089# TODO: See how to deprecate!
1090def _sync_module_params_and_buffers(
1091    module: nn.Module,
1092    params: List[nn.Parameter],
1093    process_group: dist.ProcessGroup,
1094) -> None:
1095    """
1096    Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks.
1097
1098    Precondition: ``sync_module_states == True`` and ``self.process_group`` has
1099    been set.
1100    """
1101    module_states: List[torch.Tensor] = []
1102    for buffer in module.buffers():
1103        # Avoid re-synchronizing buffers in case of nested wrapping
1104        if not getattr(buffer, FSDP_SYNCED, False):
1105            setattr(buffer, FSDP_SYNCED, True)
1106            detached_buffer = buffer.detach()
1107            if is_traceable_wrapper_subclass(detached_buffer):
1108                # NOTE: Here we assume no nested subclasses, i.e. at most one
1109                # level of subclass in both the model's buffers and params
1110                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
1111                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
1112                module_states.extend(inner_buffers)
1113            else:
1114                module_states.append(detached_buffer)
1115
1116    for param in params:
1117        detached_param = param.detach()
1118        if is_traceable_wrapper_subclass(detached_param):
1119            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
1120            inner_params = [getattr(detached_param, attr) for attr in attrs]
1121            module_states.extend(inner_params)
1122        else:
1123            module_states.append(detached_param)
1124
1125    _check_module_states_for_sync_module_states(module_states)
1126    _sync_params_and_buffers(
1127        process_group,
1128        module_states,
1129        PARAM_BROADCAST_BUCKET_SIZE,
1130        src=0,
1131    )
1132
1133
1134def _check_module_states_for_sync_module_states(
1135    module_states: List[torch.Tensor],
1136) -> None:
1137    if module_states and any(
1138        tensor.device == torch.device("cpu") for tensor in module_states
1139    ):
1140        raise ValueError(
1141            "The module has CPU parameters or buffers when `sync_module_states=True`, "
1142            "which requires them to be on GPU. Please specify the `device_id` argument "
1143            "or move the module to GPU before passing it to FSDP."
1144        )
1145
1146
1147def _get_orig_params(
1148    module: nn.Module,
1149    ignored_params: Set[nn.Parameter],
1150) -> Iterator[nn.Parameter]:
1151    """
1152    Return an iterator over the original parameters in ``module``.
1153
1154    The iterator does not return
1155    the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be
1156    present due to nested FSDP wrapping), or any original parameters already
1157    flattened (only relevant when ``use_orig_params=True``).
1158    """
1159    param_gen = module.parameters()
1160    try:
1161        while True:
1162            param = next(param_gen)
1163            if param not in ignored_params and not _is_fsdp_flattened(param):
1164                yield param
1165    except StopIteration:
1166        pass
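

# Usage sketch (illustrative only, not part of the original module): ignored
# parameters are skipped by the iterator.
#
#   m = nn.Linear(2, 2)
#   list(_get_orig_params(m, {m.bias}))  # -> [m.weight]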
1167
1168
1169def _check_orig_params_flattened(
1170    fsdp_module,
1171    ignored_params: Set[nn.Parameter],
1172) -> None:
1173    """
1174    Check that original parameters in ``fsdp_module`` have been flattened.
1175
1176    The flattened parameters are made
1177    invisible to ``named_parameters()`` for the module hierarchy rooted at
1178    ``fsdp_module``. This should be called as a sanity check after flattening
1179    the wrapped module's parameters.
1180    """
1181    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
1182        if param not in ignored_params and not _is_fsdp_flattened(param):
1183            raise RuntimeError(
1184                f"Found an unflattened parameter: {param_name}; "
1185                f"{param.size()} {param.__class__}"
1186            )
1187
1188
1189def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
1190    return (
1191        default_hooks.allreduce_hook
1192        if sharding_strategy == ShardingStrategy.NO_SHARD
1193        else default_hooks.reduce_scatter_hook
1194    )
1195
1196
1197def _get_default_comm_hook_state(
1198    process_group: dist.ProcessGroup,
1199) -> default_hooks.DefaultState:
1200    return default_hooks.DefaultState(process_group=process_group)
1201