# mypy: allow-untyped-defs
"""
This file includes private common utilities for FSDP.
"""
import logging
import traceback
import warnings
import weakref
from enum import auto, Enum
from functools import partial
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._flat_param as flat_param_file
import torch.nn as nn
from torch.distributed._composable_state import _get_module_state, _State
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)
from torch.distributed.utils import _apply_to_tensors
from torch.utils._mode_utils import no_dispatch

from .api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    OptimStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)


if TYPE_CHECKING:
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions

    from ._flat_param import FlatParamHandle

FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
FSDP_FLATTENED = "_fsdp_flattened"

# Save a global mapping from module to its input tensor dtype to be populated
# during the forward pre-hook and consumed in the forward post-hook when
# overriding a module's mixed precision
# NOTE: We currently take the last input tensor's dtype in the case of multiple
# floating-point input tensors, which may be incorrect. However, since there is
# not a 1:1 correspondence between input and output tensors, we must use *some*
# heuristic like this to predict the desired output dtype.
_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()


class _FSDPDeviceHandle:
    """
    A simple abstraction for FSDP compute devices, which enables custom
    backends that implement CUDA-like semantics to be integrated with FSDP.
    """

    def __init__(self, device: torch.device, backend: Any = None):
        self.__device = device
        if backend is None:
            try:
                self.__backend = getattr(torch, device.type)
            except AttributeError as exc:
                raise AttributeError(
                    f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'."
                ) from exc
        else:
            self.__backend = backend

    @classmethod
    def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle":
        """
        Return a device handle corresponding to ``device``; through this handle,
        operations with the same semantics as CUDA can be performed on the device.
        If the device is a CUDA or MTIA device, return ``torch.cuda`` or
        ``torch.mtia`` directly to make attribute access faster.
        A custom backend must first register a module with the same name as
        ``device.type`` on the ``torch`` package.
        """
        if device.type == "cuda":
            return cast(_FSDPDeviceHandle, torch.cuda)
        elif device.type == "mtia":
            return cast(_FSDPDeviceHandle, torch.mtia)
        return cls(device)

    def __getattr__(self, __name: str) -> Any:
        try:
            return getattr(self.__backend, __name)
        except AttributeError as exc:
            raise AttributeError(
                f"Custom backend '{self.__device.type}' does not implement 'torch.{self.__device.type}.{__name}'"
            ) from exc


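# Example (illustrative sketch): obtaining a device handle for a compute device
# and using it with CUDA-like semantics. The device choice below is arbitrary.
#
#   handle = _FSDPDeviceHandle.from_device(torch.device("cuda", 0))
#   handle.synchronize()      # same as torch.cuda.synchronize()
#   stream = handle.Stream()  # same as torch.cuda.Stream()
#
# For a hypothetical custom backend "foo", a module registered as ``torch.foo``
# must expose the CUDA-like APIs that FSDP relies on (streams, events, etc.).

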
class _UninitializedDeviceHandle(_FSDPDeviceHandle):
    def __init__(self) -> None:
        pass

    def __getattribute__(self, __name: str) -> Any:
        raise RuntimeError("Trying to use an uninitialized device handle.")


class _FSDPState(_State):
    def __init__(self) -> None:
        # TODO: Move all the attributes to this class to enable typing for
        # FSDP/fully_shard.
        self._ignored_modules: Set[nn.Module] = set()
        self._ignored_params: Set[nn.Parameter] = set()
        # Buffer names are cleaned (without wrapper prefixes)
        self._ignored_buffer_names: Set[str] = set()
        self.process_group: Optional[dist.ProcessGroup] = None
        self.rank: int = -1
        self.world_size: int = -1
        self._device_mesh: Optional[DeviceMesh] = None
        self.sharding_strategy = ShardingStrategy.FULL_SHARD
        self._use_orig_params: bool = False
        self.training_state = TrainingState.IDLE
        self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
        self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
        self._state_dict_config: StateDictConfig = FullStateDictConfig()
        self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
        self._is_root: Optional[bool] = None
        self._handle: Optional[flat_param_file.FlatParamHandle] = None
        self._fully_sharded_module_to_handle: Dict[
            nn.Module, Optional[flat_param_file.FlatParamHandle]
        ] = {}
        self.compute_device: Optional[torch.device] = None
        self._gradient_predivide_factor: int = 0
        self._gradient_postdivide_factor: int = 0
        self._comm_hook: Optional[Callable] = None
        self._comm_hook_state: Optional[Any] = None
        self._unshard_event: Optional[torch.Event] = None
        # Abstract device handle for fsdp compute device. For now,
        # the compute device must implement cuda semantics used by fsdp
        self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle()
        # All following attributes should only be used for root states:
        # Save these static lists to avoid the repeated tree traversals
        self._all_fsdp_states: List[_FSDPState] = []
        self._all_handles: List[flat_param_file.FlatParamHandle] = []
        self._fsdp_extension: Optional[FSDPExtensions] = None


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    state = _get_module_state(module)
    if state is None or not isinstance(state, _FSDPState):
        return None
    return state


def _get_module_fsdp_state_if_fully_sharded_module(
    module: nn.Module,
) -> Optional[_FSDPState]:
    state = _get_module_fsdp_state(module)
    if state is None:
        return None
    if state == module:  # FullyShardedDataParallel module case.
        return state
    if module in state._fully_sharded_module_to_handle:  # fully_shard case.
        return state
    return None


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


def _is_composable(state: _FSDPState):
    # TODO: This is a temporary hack to differentiate between code paths.
    return not isinstance(state, nn.Module)


@no_type_check
def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]:
    """
    Returns the ``FlatParamHandle`` corresponding to ``module``, i.e. the
    handle that contains some parameter in ``module``, if any.
    """
    if _is_composable(state):
        # A valid FSDP state may have no managed parameters and hence no
        # handle, meaning no entry in `_fully_sharded_module_to_handle`
        if state._handle is None:
            return None
        assert (
            module in state._fully_sharded_module_to_handle
        ), f"Expects a fully sharded module but got {module} on rank {state.rank}"
        return state._fully_sharded_module_to_handle[module]
    else:
        # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
        return module._handle


@no_type_check
def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
    """Returns whether ``module`` has parameters managed by FSDP."""
    return _module_handle(state, module) is not None


def _get_sharding_strategy(handle):
    """
    Returns the sharding strategy of the handle.
    """
    return handle._sharding_strategy if handle else None


def clean_tensor_name(tensor_name: str) -> str:
    """
    Cleans the parameter or buffer name by removing any module wrapper
    prefixes.
    """
    tensor_name = tensor_name.replace(FSDP_PREFIX, "")
    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
    # module wrappers.
    tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
    return tensor_name


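# Example (illustrative): with the default FSDP wrapping, a prefixed name such
# as "layer1._fsdp_wrapped_module.weight" is cleaned back to its original FQN:
#
#   >>> clean_tensor_name("layer1._fsdp_wrapped_module.weight")
#   'layer1.weight'

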
def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
    """
    Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
    avoid re-flattening it during nested construction.
    """
    setattr(tensor, FSDP_FLATTENED, True)


def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
    """Returns whether ``tensor`` has been marked as flattened by FSDP."""
    return getattr(tensor, FSDP_FLATTENED, False)


def _named_parameters_with_duplicates(
    module: nn.Module, **kwargs: Any
) -> List[Tuple[str, nn.Parameter]]:
    """
    This API is required as some modules overwrite `named_parameters()` but do not support
    `remove_duplicate`.
    """
    assert (
        "remove_duplicate" not in kwargs
    ), "_named_parameters_with_duplicates cannot be used with the `remove_duplicate` argument."
    kwargs["remove_duplicate"] = False
    try:
        ret = list(module.named_parameters(**kwargs))
    except AssertionError:
        kwargs.pop("remove_duplicate")
        ret = list(module.named_parameters(**kwargs))
    return ret


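# Example (illustrative sketch): for a module with tied parameters, e.g.
#
#   lin = nn.Linear(2, 2, bias=False)
#   tied = nn.Sequential(lin, lin)
#
# ``_named_parameters_with_duplicates(tied)`` returns both FQNs, "0.weight" and
# "1.weight", for the single shared parameter, whereas the default
# ``named_parameters()`` would deduplicate down to "0.weight" only.

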
def _get_param_to_fqns(
    model: torch.nn.Module,
    dedup_shared_params: bool = True,
) -> Dict[nn.Parameter, List[str]]:
    """
    Constructs a mapping from parameter to a list of its "canonical" FQNs. Here,
    we use canonical to mean the fully-qualified name assigned to the parameter
    based on its position in the original nn.Module hierarchy before any wrapper
    or parallelism has been applied to it. This is in contrast to FQNs that may be
    generated after parallelisms or wrappers have been applied to the model.

    Each normal parameter maps to a singleton list containing its FQN, while each
    ``FlatParameter`` maps to a list of its original parameter FQNs, which may
    have length greater than one. All FQNs are prefixed starting from ``model``.

    In the case where FSDP was applied with ``use_orig_params=True``, there should be no
    ``FlatParameter`` s registered to the model's modules and this mapping will only
    contain mappings from ``nn.Parameter`` s to singleton FQN lists.

    Only in the case where FSDP was applied with ``use_orig_params=False`` will a
    ``FlatParameter`` be registered in place of the original parameters, with mappings
    from each ``FlatParameter`` to the list of FQNs corresponding to the original
    parameters.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
        dedup_shared_params (bool): For shared parameters, if ``True``, only
            includes the FQNs corresponding to the first encounter of the
            shared parameter in the module traversal; if ``False``, then
            includes the FQNs across all encounters. (Default: ``True``)
    """

    def module_fn(module, prefix, tree_level, param_to_fqns):
        for param_name, param in _named_parameters_with_duplicates(
            module, recurse=False
        ):
            local_fqns = (
                param._fqns
                if isinstance(param, flat_param_file.FlatParameter)
                else [param_name]
            )  # prefixed from `module`
            global_fqns = [
                clean_tensor_name(prefix + name) for name in local_fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)
            is_shared_param = param in param_to_fqns
            if not is_shared_param:
                param_to_fqns[param] = global_fqns
            else:
                if isinstance(param, flat_param_file.FlatParameter):
                    # DMP overwrites `named_parameters` and skips (advances to
                    # the next child module) the wrapped module (e.g.,
                    # _dmp_wrapped_module and _fsdp_wrapped_module). When a user
                    # calls `named_children` to traverse the module recursively
                    # and calls `named_parameters` with `recurse=False`,
                    # parameters will be traversed more than once.
                    # This hack is specifically designed for DMP + FSDP. We
                    # overwrite the ``FlatParameter`` traversal result to only
                    # keep the last one, which happens to be the correct one.
                    #
                    # TODO: Remove this hack once DMP + FSDP is not supported.
                    warnings.warn(
                        "FlatParameter is being traversed more than once. "
                        "This case should only happen when using "
                        "DistributedModelParallel with FullyShardedDataParallel."
                    )
                    param_to_fqns[param] = global_fqns
                elif not dedup_shared_params:
                    param_to_fqns[param].extend(global_fqns)

    def return_fn(param_to_fqns):
        return param_to_fqns

    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in _named_parameters_with_duplicates(model)],
        param_to_unflat_param_names,
    )


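# Example (illustrative sketch; the model below is hypothetical): for
#
#   model = FSDP(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)))
#
# wrapped with ``use_orig_params=False``, the single registered ``FlatParameter``
# maps to the canonical FQNs of its original parameters, e.g.
#
#   {flat_param: ["0.weight", "0.bias", "1.weight", "1.bias"]}
#
# whereas with ``use_orig_params=True`` each original ``nn.Parameter`` maps to a
# singleton list such as ``{param: ["0.weight"], ...}``.

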
@no_type_check
def _log_post_backward_hook(
    state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger
) -> None:
    # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the parameter FQNs that this
    # post-backward hook fires for. This can help debug cases where hooks do
    # not fire, such as under certain activation checkpoint configs.
    if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO:
        param_fqns = _get_handle_fqns_from_root(state, handle)
        logger.warning("FSDP firing post-backward hooks for parameters %s", param_fqns)


@no_type_check
def _get_handle_fqns_from_root(
    state: _FSDPState, handle: "FlatParamHandle"
) -> Optional[List[str]]:
    if handle is None:
        return None
    param_to_fqn = state._exec_order_data.param_to_fqn
    handle_params = handle.flat_param._params  # only populated for use_orig_params
    param_fqns = [
        fqn for fqn_list in [param_to_fqn[p] for p in handle_params] for fqn in fqn_list
    ]
    return param_fqns


def _apply_to_modules(
    root_module: torch.nn.Module,
    module_fn: Callable,
    return_fn: Callable,
    filter_fqns: Optional[List[str]] = None,
    *args,
    **kwargs,
):
    """
    Performs a pre-order traversal of the modules in the hierarchy rooted at
    ``root_module``, applying ``module_fn`` at each module and finally
    returning a value using ``return_fn``. The traversal constructs the full
    module prefix name (e.g. "module.submodule.", just as in the model state dict)
    and makes that available to ``module_fn``.

    ``filter_fqns`` is needed because some modules introduce their own prefix
    (similar to ``FullyShardedDataParallel``) and overwrite ``named_parameters()``
    to remove that prefix.
    """

    def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
        # Call the module function before recursing over children (pre-order)
        module_fn(module, prefix, tree_level, *args, **kwargs)
        for submodule_name, submodule in module.named_children():
            if submodule is None:
                continue
            new_prefix = prefix + submodule_name + "."
            new_tree_level = tree_level + 1
            if filter_fqns is not None:
                for fqn in filter_fqns:
                    if fqn.startswith(new_prefix):
                        break
                else:
                    # DMP's `named_parameters()` will mess up the traversal with
                    # ``named_children`` + `named_parameters(recurse=False)`.
                    # This hack is required to make the traversal work.
                    # TODO: Remove this hack once DMP + FSDP is not supported.
                    # It turns out that recursive wrapping may trigger this as
                    # well.
                    if (
                        submodule_name == "_fsdp_wrapped_module"
                        or submodule_name == "_dmp_wrapped_module"
                    ):
                        new_prefix = prefix
                    elif submodule_name == "module":
                        new_prefix = prefix
            f(submodule, new_prefix, new_tree_level, *args, **kwargs)

    f(root_module, "", 0, *args, **kwargs)
    return return_fn(*args, **kwargs)


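# Example (illustrative sketch; the helpers below are hypothetical): collecting
# every module's prefixed name with ``_apply_to_modules``.
#
#   def module_fn(module, prefix, tree_level, names):
#       names.append(prefix)
#
#   def return_fn(names):
#       return names
#
#   # For ``model = nn.Sequential(nn.Linear(2, 2))``, this returns ["", "0."].
#   prefixes = _apply_to_modules(model, module_fn, return_fn, None, [])

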
@no_type_check
def _assert_in_training_states(
    state: _FSDPState,
    training_states: List[TrainingState],
) -> None:
    """Asserts that the FSDP instance is in one of the states ``training_states``."""
    # Raise a `ValueError` instead of using `assert` to ensure that these
    # logical assertions run even if `assert`s are disabled
    if state.training_state not in training_states:
        msg = (
            f"expected to be in states {training_states} but current state is "
            f"{state.training_state}"
        )
        # Print the error on rank 0 in case this is called in the backward pass
        if state.rank == 0:
            if isinstance(state, nn.Module):
                print(f"Asserting FSDP instance is: {state}")
            print(f"ERROR: {msg}")
            traceback.print_stack()
        raise ValueError(msg)


def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
    """
    Returns:
        Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
        parent-less) with respect to the modules in the set itself. In other
        words, these are the modules in ``modules`` that are not the child of
        any other module in ``modules``.
    """
    root_modules: Set[nn.Module] = set()
    module_to_submodules = {module: set(module.modules()) for module in modules}
    for candidate_module in modules:
        is_root_module = True
        for module, submodules in module_to_submodules.items():
            is_child_module = (
                candidate_module is not module and candidate_module in submodules
            )
            if is_child_module:
                is_root_module = False
                break
        if is_root_module:
            root_modules.add(candidate_module)
    return root_modules


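# Example (illustrative): if ``model = nn.Sequential(nn.Linear(2, 2))`` and
# ``modules = {model, model[0]}``, then ``_get_root_modules(modules)`` returns
# ``{model}`` because ``model[0]`` is a child of ``model``.

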
def _override_module_mixed_precision(
    root: torch.nn.Module,
    module_classes_to_override: Iterable[Type[nn.Module]],
    wrap_override_dict: Dict[str, Any] = {"mixed_precision": None},  # noqa: B006
) -> Set[Type[nn.Module]]:
    module_classes_to_override = tuple(set(module_classes_to_override))
    # Return a set of the actually overridden module classes
    overridden_module_classes: Set[Type[nn.Module]] = set()
    for mod in root.modules():
        if isinstance(mod, module_classes_to_override):
            overridden_module_classes.add(type(mod))
            mod._wrap_overrides = wrap_override_dict  # type: ignore[assignment]
            # TODO: We need to run this mixed precision ignored module in fp32,
            # but ensure that subsequent modules, which may be running with
            # mixed precision, still receive inputs of the appropriate precision
            # without the user having to adjust the mixed precision config too
            # much. As a result, we attach pre- and post-forward hooks to upcast
            # / downcast. We should revisit this design.

            def cast_fn(
                dtype: torch.dtype, module: nn.Module, x: torch.Tensor
            ) -> torch.Tensor:
                if not torch.is_floating_point(x) or x.dtype == dtype:
                    return x
                _MODULE_TO_INP_DTYPE[module] = x.dtype
                return x.to(dtype)

            def forward_pre_hook(module, args):
                return _apply_to_tensors(partial(cast_fn, torch.float32, module), args)

            def forward_post_hook(module, args, output):
                # NOTE: If the forward did not have any floating-point tensors,
                # then the dtype will not be set for this module, and we do not
                # upcast the dtype.
                if module in _MODULE_TO_INP_DTYPE:
                    old_dtype = _MODULE_TO_INP_DTYPE[module]
                    return _apply_to_tensors(
                        partial(cast_fn, old_dtype, module), output
                    )

            # We intentionally append both of these hooks so that they run after
            # all other hooks.
            mod.register_forward_pre_hook(forward_pre_hook, prepend=False)
            mod.register_forward_hook(forward_post_hook, prepend=False)
    return overridden_module_classes


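# Example (illustrative sketch; ``MyNormLayer`` is a hypothetical class): to keep
# instances of a particular layer type running in fp32 under FSDP mixed
# precision, one could mark them before wrapping:
#
#   overridden = _override_module_mixed_precision(model, [MyNormLayer])
#
# Each matching submodule gets ``_wrap_overrides = {"mixed_precision": None}``
# plus forward hooks that cast floating-point inputs to fp32 and cast outputs
# back to the original input dtype.

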
def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
    # FIXME record_stream doesn't work with non-cuda/mtia tensors
    if tensor.device.type not in [
        "cuda",
        "mtia",
        torch._C._get_privateuse1_backend_name(),
    ]:
        return

    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        # from @ezyang:
        # The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin
        # Looking over the PR, it looks like this is because we don't actually support Stream arguments
        # in torch dispatch, so it just chokes.
        # If Dynamo is able to answer "are there any torch dispatch modes active" (it should answer False),
        # a better version of this would just be to check if there are any modes before disabling dispatch.
        # TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here.
        return
    else:
        with no_dispatch():
            tensor.record_stream(stream)
559