xref: /aosp_15_r20/external/pytorch/torch/distributed/fsdp/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2This file includes public APIs for FSDP such as the classes used for the
3constructor arguments.
4"""
5
6from dataclasses import dataclass
7from enum import auto, Enum
8from typing import Optional, Sequence, Type
9
10import torch
11from torch.nn.modules.batchnorm import _BatchNorm
12
13
# Public API of this module: the names exported by a star-import. Every
# class defined in this file is listed here.
__all__ = [
    "ShardingStrategy",
    "BackwardPrefetch",
    "MixedPrecision",
    "CPUOffload",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateDictConfig",
    "FullOptimStateDictConfig",
    "LocalOptimStateDictConfig",
    "ShardedOptimStateDictConfig",
    "StateDictSettings",
]
30
31
class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate
      parameters across nodes. This results in reduced communication volume as
      expensive all-gathers and reduce-scatters are only done within a node,
      which can be more performant for medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and
      replicate parameters across nodes. This is like ``HYBRID_SHARD``, except
      this may provide even higher throughput since the unsharded parameters
      are not freed after the forward pass, saving the all-gathers in the
      pre-backward.
    """

    # Member values come from ``auto()``, i.e. they are assigned in
    # declaration order; refer to members by name rather than by value.
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()
70
71
class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which improves throughput by
    enabling communication and computation overlap in the backward pass at the
    cost of slightly increased memory usage.

    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
      usage the most. This prefetches the next set of parameters *before* the
      current set of parameters' gradient computation. This overlaps the *next
      all-gather* and the *current gradient computation*, and at the peak, it
      holds the current set of parameters, next set of parameters, and current
      set of gradients in memory.
    - ``BACKWARD_POST``: This enables less overlap but requires less memory
      usage. This prefetches the next set of parameters *after* the current
      set of parameters' gradient computation. This overlaps the *current
      reduce-scatter* and the *next gradient computation*, and it frees the
      current set of parameters before allocating memory for the next set of
      parameters, only holding the next set of parameters and current set of
      gradients in memory at the peak.
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
      the backward prefetching altogether. This has no overlap and does not
      increase memory usage. In general, we do not recommend this setting since
      it may degrade throughput significantly.

    For more technical context: For a single process group using the NCCL
    backend, any collectives, even if issued from different streams, contend
    for the same per-device NCCL stream, which implies that the relative order
    in which the collectives are issued matters for overlapping. The two
    backward prefetching values correspond to different issue orders.
    """

    # NOTE: For both modes, the ordering that defines "current" and "next" is
    # not always exact in the current implementation. A mistargeted prefetch
    # simply means that the parameter memory is allocated earlier than needed,
    # possibly increasing peak memory usage, but does not affect correctness.
    #
    # Member values come from ``auto()`` and follow declaration order.
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
109
110
@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
            parameters during forward and backward and thus the dtype for
            forward and backward computation. Outside forward and backward, the
            *sharded* parameters are kept in full precision (e.g. for the
            optimizer step), and for model checkpointing, the parameters are
            always saved in full precision. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
            the ``param_dtype`` value, still running gradient reduction in low
            precision. This is permitted to differ from ``param_dtype``, e.g.
            to force gradient reduction to run in full precision. (Default:
            ``None``)
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
            ``buffer_dtype`` in the first forward pass and keeps them in that
            dtype thereafter. For model checkpointing, the buffers are saved
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
            ``None``)
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
            gradients to full precision after the backward pass in preparation
            for the optimizer step. If ``True``, then FSDP keeps the gradients
            in the dtype used for gradient reduction, which can save memory if
            using a custom optimizer that supports running in low precision.
            (Default: ``False``)
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
            its forward args and kwargs to ``param_dtype``. This is to ensure
            that parameter and input dtypes match for forward computation, as
            required by many ops. This may need to be set to ``True`` when only
            applying mixed precision to some but not all FSDP modules, in which
            case a mixed-precision FSDP submodule needs to recast its inputs.
            (Default: ``False``)
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
            casts its forward args and kwargs to ``param_dtype``, overriding
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
            this does not do anything. (Default: ``True``)
        _module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
            module classes to ignore for mixed precision when using an
            ``auto_wrap_policy``: Modules of these classes will have FSDP
            applied to them separately with mixed precision disabled (meaning
            that the final FSDP construction would deviate from the specified
            policy). If ``auto_wrap_policy`` is not specified, then this does
            not do anything. This API is experimental and subject to change.
            (Default: ``(_BatchNorm,)``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
            >>> model[1] = FSDP(
            >>>     model[1],
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
            >>> )
            >>> model = FSDP(
            >>>     model,
            >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
            >>> )

        The above shows a working example. On the other hand, if ``model[1]``
        were replaced with ``model[0]``, meaning that the submodule using
        different ``MixedPrecision`` ran its forward first, then ``model[1]``
        would incorrectly see ``float16`` activations instead of ``bfloat16``
        ones.

    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True
    # The default is a tuple (immutable), so it can be used directly as a
    # dataclass field default without a ``default_factory``.
    _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)
227
228
@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU. (Default: ``False``)
    """

    offload_params: bool = False
242
243
class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).

    The default value is ``FULL_STATE_DICT`` to comply with the PyTorch
    convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict/load_state_dict``: this pair of APIs returns and loads
           the non-sharded, unflattened parameters. The semantics are the
           same as using DDP.
        2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs
           returns and loads local sharded, flattened parameters. The values
           returned by ``_local_state_dict`` can be directly used by FSDP and
           are only meaningful to FSDP (because parameters are flattened).
           Note that these APIs are meant for use via the
           :func:`state_dict_type` context manager as follows:

               >>> # xdoctest: +SKIP("undefined variables")
               >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
               ...     state = fsdp.state_dict()  # returns the local state dict
        3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs
           returns and loads sharded, unflattened parameters. The
           ``state_dict`` returned by ``_sharded_state_dict`` can be used by
           all other parallel schemes (resharding may be required).
    """

    # Member values come from ``auto()`` and follow declaration order.
    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()
272
273
@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the
    corresponding ``state_dict`` type supported by FSDP.

    The child classes defined in this module are ``FullStateDictConfig``,
    ``LocalStateDictConfig``, and ``ShardedStateDictConfig``.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
            (Default: ``False``)
    """

    offload_to_cpu: bool = False
289
290
@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument
        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.

    Attributes:
        offload_to_cpu (bool): Inherited from ``StateDictConfig``.
            (Default: ``False``)
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False
326
327
@dataclass
class LocalStateDictConfig(StateDictConfig):
    """
    ``LocalStateDictConfig`` is a ``StateDictConfig`` subclass that adds no
    fields of its own; it only carries the inherited ``offload_to_cpu``
    setting (Default: ``False``).
    """

    # NOTE(review): by analogy with the sibling config classes, this is
    # presumably the config for ``StateDictType.LOCAL_STATE_DICT`` -- confirm
    # against the FSDP implementation before documenting it publicly.
    pass
331
332
@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """
    ``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        offload_to_cpu (bool): Inherited from ``StateDictConfig``.
            (Default: ``False``)
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedStateDictConfig` and it is used by FSDP to determine the
        type of state dict values. Users should not manually modify
        ``_use_dtensor``.
    """

    _use_dtensor: bool = False
350
351
@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    corresponding ``optim_state_dict`` type supported by FSDP.

    The child classes defined in this module are
    ``FullOptimStateDictConfig``, ``LocalOptimStateDictConfig``, and
    ``ShardedOptimStateDictConfig``.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
            original device (which is GPU unless parameter CPU offloading is
            enabled). (Default: ``True``)
    """

    offload_to_cpu: bool = True
368
369
@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
    """
    ``FullOptimStateDictConfig`` is an ``OptimStateDictConfig`` subclass that
    adds the ``rank0_only`` option.

    Attributes:
        offload_to_cpu (bool): Inherited from ``OptimStateDictConfig``.
            (Default: ``True``)
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    # NOTE(review): presumably the optimizer-state counterpart of
    # ``FullStateDictConfig`` for ``StateDictType.FULL_STATE_DICT`` -- confirm.
    rank0_only: bool = False
380
381
@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
    """
    ``LocalOptimStateDictConfig`` is an ``OptimStateDictConfig`` subclass whose
    only change is the default of ``offload_to_cpu``.

    Attributes:
        offload_to_cpu (bool): Same field as on ``OptimStateDictConfig``, but
            the default is overridden here. (Default: ``False``, whereas the
            base class defaults to ``True``)
    """

    # NOTE(review): presumably the optimizer-state config paired with
    # ``StateDictType.LOCAL_STATE_DICT`` -- confirm against the caller.
    offload_to_cpu: bool = False
385
386
@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
    """
    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        offload_to_cpu (bool): Inherited from ``OptimStateDictConfig``.
            (Default: ``True``)
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedOptimStateDictConfig` and it is used by FSDP to
        determine the type of state dict values. Users should not manually
        modify ``_use_dtensor``.
    """

    _use_dtensor: bool = False
404
405
@dataclass
class StateDictSettings:
    """
    ``StateDictSettings`` groups a ``state_dict`` type together with the
    model and optimizer ``state_dict`` configs.

    All three fields are required; none has a default value.

    Attributes:
        state_dict_type (StateDictType): The ``state_dict`` type.
        state_dict_config (StateDictConfig): The model ``state_dict`` config.
        optim_state_dict_config (OptimStateDictConfig): The optimizer
            ``state_dict`` config.
    """

    # NOTE(review): presumably the two configs are expected to match
    # ``state_dict_type`` (e.g. the Full* configs with FULL_STATE_DICT);
    # nothing in this class enforces that -- confirm against the caller.
    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig
411