# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """
    This configures FSDP's mixed precision. Unlike autocast, this applies mixed
    precision at the module level, not op level, which means low-precision
    activations are saved for backward and high-to-low-precision casts are
    incurred only at module boundaries.

    FSDP works well with module-level mixed precision since it keeps the
    high-precision sharded parameters in memory anyway. In other words, FSDP
    does not require any extra memory to keep a high-precision copy of the
    parameters for the optimizer step.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for
            the unsharded parameter and hence the dtype for forward/backward
            computation and the parameter all-gather. If this is ``None``, then
            the unsharded parameter uses the original dtype. The optimizer step
            uses the sharded parameter in the original dtype. (Default:
            ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then the reduction
            uses the compute dtype. This can be used to run gradient reduction
            in full precision while using low precision for compute. If
            gradient reduction is disabled via :meth:`set_requires_gradient_sync`,
            then FSDP accumulates gradients in ``reduce_dtype``.
            (Default: ``None``)
        output_dtype (Optional[torch.dtype]): This specifies the dtype for
            casting floating-point forward outputs. This can be used to
            help implement cases where different modules have different mixed
            precision policies. (Default: ``None``)
        cast_forward_inputs (bool): This specifies whether FSDP should cast the
            forward's floating-point input tensors to ``param_dtype``.
            (Default: ``True``)
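
    A minimal usage sketch, assuming the :func:`fully_shard` entry point from
    this package, which accepts this policy via its ``mp_policy`` argument::

        import torch
        import torch.nn as nn
        from torch.distributed._composable.fsdp import (
            MixedPrecisionPolicy,
            fully_shard,
        )

        # Assumes torch.distributed is already initialized (process group /
        # device mesh) as required by fully_shard.
        model = nn.Linear(8, 8)
        # Compute and all-gather in bf16 while reducing gradients in fp32.
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        )
        fully_shard(model, mp_policy=mp_policy)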
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True

    def __post_init__(self):
        # Clamp `reduce_dtype` to `None` when it matches `param_dtype`: since
        # gradients are already computed in `param_dtype`, no extra cast is
        # needed for the reduction
        if self.param_dtype == self.reduce_dtype:
            # Bypass the frozen dataclass checks
            object.__setattr__(self, "reduce_dtype", None)


@dataclass
class OffloadPolicy:
    """This base class represents the policy of no offloading."""


@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    This offload policy offloads parameters, gradients, and optimizer states to
    CPU. Sharded parameters are copied host-to-device before the all-gather,
    and the all-gathered parameters are freed according to
    ``reshard_after_forward``. Sharded gradients are copied device-to-host
    during backward, and the optimizer step runs on CPU with CPU optimizer
    states.

    Attributes:
        pin_memory (bool): Whether to pin sharded parameter and gradient
            memory. Pinning memory allows H2D/D2H copies to run without
            blocking the CPU and, in turn, to overlap with compute, but pinned
            memory cannot be used by other processes. Set this to ``False`` if
            you have insufficient CPU memory. (Default: ``True``)
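
    A minimal usage sketch, assuming the :func:`fully_shard` entry point from
    this package, which accepts this policy via its ``offload_policy``
    argument::

        import torch
        import torch.nn as nn
        from torch.distributed._composable.fsdp import (
            CPUOffloadPolicy,
            fully_shard,
        )

        # Assumes torch.distributed is already initialized (process group /
        # device mesh) as required by fully_shard.
        model = nn.Linear(8, 8)
        fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))
        # Constructing the optimizer after sharding keeps its states on CPU,
        # and the optimizer step runs on CPU.
        optim = torch.optim.SGD(model.parameters(), lr=1e-2)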
    """

    pin_memory: bool = True
81