# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class MixedPrecisionPolicy:
    """
    This configures FSDP's mixed precision. Unlike autocast, this applies mixed
    precision at the module level, not op level, which means low-precision
    activations are saved for backward and high-to-low-precision casts are
    incurred only at module boundaries.

    FSDP works well with module-level mixed precision since it keeps the
    high-precision sharded parameters in memory anyway. In other words, FSDP
    does not require any extra memory to keep a high-precision copy of the
    parameters for the optimizer step.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for
            the unsharded parameter and hence the dtype for forward/backward
            computation and the parameter all-gather. If this is ``None``, then
            the unsharded parameter uses the original dtype. The optimizer step
            uses the sharded parameter in the original dtype. (Default:
            ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then the reduction
            uses the compute dtype. This can be used to run gradient reduction
            in full precision while using low precision for compute. If
            gradient reduction is also disabled via
            :meth:`set_requires_gradient_sync`, then FSDP will accumulate
            gradients using ``reduce_dtype``. (Default: ``None``)
        output_dtype (Optional[torch.dtype]): This specifies the dtype for
            casting floating-point forward outputs. This can be used to
            help implement cases where different modules have different mixed
            precision policies. (Default: ``None``)
        cast_forward_inputs (bool): This specifies whether FSDP should cast the
            forward's floating-point input tensors to ``param_dtype`` or not.
            (Default: ``True``)
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    output_dtype: Optional[torch.dtype] = None
    cast_forward_inputs: bool = True

    def __post_init__(self):
        # Clamp `reduce_dtype` to `None` if no casting is required: since
        # gradients are computed in `param_dtype`, if `reduce_dtype` matches,
        # then we do not need extra casting
        if self.param_dtype == self.reduce_dtype:
            # Bypass the frozen dataclass checks
            object.__setattr__(self, "reduce_dtype", None)


@dataclass
class OffloadPolicy:
    """This base class represents the policy of no offloading."""


@dataclass
class CPUOffloadPolicy(OffloadPolicy):
    """
    This offload policy offloads parameters, gradients, and optimizer states to
    CPU. Sharded parameters are copied host-to-device before all-gather. The
    all-gathered parameters are freed according to ``reshard_after_forward``.
    Sharded gradients are copied device-to-host in backward, and the optimizer
    step runs on CPU with CPU optimizer states.

    Attributes:
        pin_memory (bool): Whether to pin sharded parameter and gradient
            memory. Pinning memory allows H2D/D2H copying without blocking the
            CPU and, in turn, enables overlap with compute, but pinned memory
            cannot be used by other processes. Set this to ``False`` if you
            have insufficient CPU memory. (Default: ``True``)
    """

    pin_memory: bool = True
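

# --- Illustrative usage sketch (not part of this module's API) --------------
# The function below is a hedged example of how these policies are typically
# passed to the FSDP2 ``fully_shard`` entry point via its ``mp_policy`` and
# ``offload_policy`` keyword arguments. The import path and keyword names are
# assumptions about the FSDP2 API surface and may vary across releases; a
# process group / device mesh must already be initialized before calling
# ``fully_shard``. This function is never invoked by this module.
def _example_usage(model: torch.nn.Module) -> torch.nn.Module:
    # Deferred import; assumed FSDP2 import path.
    from torch.distributed.fsdp import fully_shard

    # bf16 compute with fp32 gradient reduction; floating-point forward inputs
    # are cast to ``param_dtype`` since ``cast_forward_inputs`` defaults to True.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    # Offload sharded parameters, gradients, and optimizer states to CPU,
    # pinning host memory so that H2D/D2H copies can overlap with compute.
    offload_policy = CPUOffloadPolicy(pin_memory=True)
    fully_shard(model, mp_policy=mp_policy, offload_policy=offload_policy)
    return model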