1""" 2This file includes public APIs for FSDP such as the classes used for the 3constructor arguments. 4""" 5 6from dataclasses import dataclass 7from enum import auto, Enum 8from typing import Optional, Sequence, Type 9 10import torch 11from torch.nn.modules.batchnorm import _BatchNorm 12 13 14__all__ = [ 15 "ShardingStrategy", 16 "BackwardPrefetch", 17 "MixedPrecision", 18 "CPUOffload", 19 "StateDictType", 20 "StateDictConfig", 21 "FullStateDictConfig", 22 "LocalStateDictConfig", 23 "ShardedStateDictConfig", 24 "OptimStateDictConfig", 25 "FullOptimStateDictConfig", 26 "LocalOptimStateDictConfig", 27 "ShardedOptimStateDictConfig", 28 "StateDictSettings", 29] 30 31 32class ShardingStrategy(Enum): 33 """ 34 This specifies the sharding strategy to be used for distributed training by 35 :class:`FullyShardedDataParallel`. 36 37 - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded. 38 For the parameters, this strategy unshards (via all-gather) before the 39 forward, reshards after the forward, unshards before the backward 40 computation, and reshards after the backward computation. For gradients, 41 it synchronizes and shards them (via reduce-scatter) after the backward 42 computation. The sharded optimizer states are updated locally per rank. 43 - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during 44 computation, and additionally, parameters are sharded outside 45 computation. For the parameters, this strategy unshards before the 46 forward, does not reshard them after the forward, and only reshards them 47 after the backward computation. The sharded optimizer states are updated 48 locally per rank. Inside ``no_sync()``, the parameters are not resharded 49 after the backward computation. 50 - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded 51 but instead replicated across ranks similar to PyTorch's 52 :class:`DistributedDataParallel` API. For gradients, this strategy 53 synchronizes them (via all-reduce) after the backward computation. The 54 unsharded optimizer states are updated locally per rank. 55 - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across 56 nodes. This results in reduced communication volume as expensive all-gathers and 57 reduce-scatters are only done within a node, which can be more performant for medium 58 -sized models. 59 - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across 60 nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput 61 since the unsharded parameters are not freed after the forward pass, saving the 62 all-gathers in the pre-backward. 63 """ 64 65 FULL_SHARD = auto() 66 SHARD_GRAD_OP = auto() 67 NO_SHARD = auto() 68 HYBRID_SHARD = auto() 69 _HYBRID_SHARD_ZERO2 = auto() 70 71 72class BackwardPrefetch(Enum): 73 """ 74 This configures explicit backward prefetching, which improves throughput by 75 enabling communication and computation overlap in the backward pass at the 76 cost of slightly increased memory usage. 77 78 - ``BACKWARD_PRE``: This enables the most overlap but increases memory 79 usage the most. This prefetches the next set of parameters *before* the 80 current set of parameters' gradient computation. This overlaps the *next 81 all-gather* and the *current gradient computation*, and at the peak, it 82 holds the current set of parameters, next set of parameters, and current 83 set of gradients in memory. 
    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which improves throughput by
    enabling communication and computation overlap in the backward pass at the
    cost of slightly increased memory usage.

    - ``BACKWARD_PRE``: This enables the most overlap but increases memory
      usage the most. This prefetches the next set of parameters *before* the
      current set of parameters' gradient computation. This overlaps the *next
      all-gather* and the *current gradient computation*, and at the peak, it
      holds the current set of parameters, next set of parameters, and current
      set of gradients in memory.
    - ``BACKWARD_POST``: This enables less overlap but requires less memory.
      This prefetches the next set of parameters *after* the current set of
      parameters' gradient computation. This overlaps the *current
      reduce-scatter* and the *next gradient computation*, and it frees the
      current set of parameters before allocating memory for the next set of
      parameters, only holding the next set of parameters and current set of
      gradients in memory at the peak.
    - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables
      backward prefetching altogether. This has no overlap and does not
      increase memory usage. In general, we do not recommend this setting since
      it may degrade throughput significantly.

    For more technical context: For a single process group using the NCCL
    backend, any collectives, even if issued from different streams, contend
    for the same per-device NCCL stream, which implies that the relative order
    in which the collectives are issued matters for overlapping. The two
    backward prefetching values correspond to different issue orders.
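
    A minimal sketch of selecting the lower-memory option is shown below; it
    assumes ``model`` is already defined and the default process group has
    been initialized::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import BackwardPrefetch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_POST)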
    """

    # NOTE: For both modes, the ordering that defines "current" and "next" is
    # not always exact in the current implementation. A mistargeted prefetch
    # simply means that the parameter memory is allocated earlier than needed,
    # possibly increasing peak memory usage, but does not affect correctness.
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()


@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (Optional[torch.dtype]): This specifies the dtype for model
            parameters during forward and backward and thus the dtype for
            forward and backward computation. Outside the forward and backward
            passes, the *sharded* parameters are kept in full precision (e.g.
            for the optimizer step), and for model checkpointing, the
            parameters are always saved in full precision. (Default: ``None``)
        reduce_dtype (Optional[torch.dtype]): This specifies the dtype for
            gradient reduction (i.e. reduce-scatter or all-reduce). If this is
            ``None`` but ``param_dtype`` is not ``None``, then this takes on
            the ``param_dtype`` value, still running gradient reduction in low
            precision. This is permitted to differ from ``param_dtype``, e.g.
            to force gradient reduction to run in full precision. (Default:
            ``None``)
        buffer_dtype (Optional[torch.dtype]): This specifies the dtype for
            buffers. FSDP does not shard buffers. Rather, FSDP casts them to
            ``buffer_dtype`` in the first forward pass and keeps them in that
            dtype thereafter. For model checkpointing, the buffers are saved
            in full precision except for ``LOCAL_STATE_DICT``. (Default:
            ``None``)
        keep_low_precision_grads (bool): If ``False``, then FSDP upcasts
            gradients to full precision after the backward pass in preparation
            for the optimizer step. If ``True``, then FSDP keeps the gradients
            in the dtype used for gradient reduction, which can save memory if
            using a custom optimizer that supports running in low precision.
            (Default: ``False``)
        cast_forward_inputs (bool): If ``True``, then this FSDP module casts
            its forward args and kwargs to ``param_dtype``. This is to ensure
            that parameter and input dtypes match for forward computation, as
            required by many ops. This may need to be set to ``True`` when only
            applying mixed precision to some but not all FSDP modules, in which
            case a mixed-precision FSDP submodule needs to recast its inputs.
            (Default: ``False``)
        cast_root_forward_inputs (bool): If ``True``, then the root FSDP module
            casts its forward args and kwargs to ``param_dtype``, overriding
            the value of ``cast_forward_inputs``. For non-root FSDP modules,
            this does not do anything. (Default: ``True``)
        _module_classes_to_ignore (Sequence[Type[nn.Module]]): This specifies
            module classes to ignore for mixed precision when using an
            ``auto_wrap_policy``: Modules of these classes will have FSDP
            applied to them separately with mixed precision disabled (meaning
            that the final FSDP construction would deviate from the specified
            policy). If ``auto_wrap_policy`` is not specified, then this does
            not do anything. This API is experimental and subject to change.
            (Default: ``(_BatchNorm,)``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: Layer norm and batch norm accumulate in ``float32`` even when
        their inputs are in a low precision like ``float16`` or ``bfloat16``.
        Disabling FSDP's mixed precision for those norm modules only means that
        the affine parameters are kept in ``float32``. However, this incurs
        separate all-gathers and reduce-scatters for those norm modules, which
        may be inefficient, so if the workload permits, the user should prefer
        to still apply mixed precision to those modules.

    .. note:: By default, if the user passes a model with any ``_BatchNorm``
        modules and specifies an ``auto_wrap_policy``, then the batch norm
        modules will have FSDP applied to them separately with mixed precision
        disabled. See the ``_module_classes_to_ignore`` argument.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting is
        sufficient for the typical case where each FSDP instance has the same
        ``MixedPrecision`` configuration and only needs to cast inputs to the
        ``param_dtype`` at the beginning of the model's forward pass.

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation dtype
        being changed due to a different ``MixedPrecision`` configuration.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
        >>> model[1] = FSDP(
        >>>     model[1],
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
        >>> )
        >>> model = FSDP(
        >>>     model,
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
        >>> )

    The above shows a working example. On the other hand, if ``model[1]``
    were replaced with ``model[0]``, meaning that the submodule using
    different ``MixedPrecision`` ran its forward first, then ``model[1]``
    would incorrectly see ``float16`` activations instead of ``bfloat16``
    ones.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True
    _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,)


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If ``True``, then this
            offloads gradients to CPU as well, meaning that the optimizer step
            runs on CPU.
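
    A minimal sketch of enabling parameter CPU offloading, assuming ``model``
    is already defined and the default process group has been initialized::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import CPUOffload
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))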
    """

    offload_params: bool = False


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading). The default value is
    ``FULL_STATE_DICT`` to comply with the PyTorch convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict``/``load_state_dict``: this pair of APIs returns and
           loads the non-sharded, unflattened parameters. The semantics are
           the same as using DDP.
        2. ``_local_state_dict``/``_load_local_state_dict``: this pair of APIs
           returns and loads local sharded, flattened parameters. The values
           returned by ``_local_state_dict`` can be directly used by FSDP and
           are only meaningful to FSDP (because parameters are flattened).
           Note that these APIs are meant for use via the
           :func:`state_dict_type` context manager as follows:

           >>> # xdoctest: +SKIP("undefined variables")
           >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
           ...     state = fsdp.state_dict()  # gets the local state dict

        3. ``_sharded_state_dict``/``_load_sharded_state_dict``: this pair of
           APIs returns and loads sharded, unflattened parameters. The
           ``state_dict`` returned by ``_sharded_state_dict`` can be used by
           all other parallel schemes (resharding may be required).
    """

    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the
    corresponding ``state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict
            values to CPU, and if ``False``, then FSDP keeps them on GPU.
            (Default: ``False``)
    """

    offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. We recommend enabling both
    ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state
    dicts to save GPU memory and CPU memory, respectively. This config class
    is meant to be used via the :func:`state_dict_type` context manager as
    follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # `state` will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn()  # Initialize model in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. The `sync_module_states` argument
        >>> # communicates loaded checkpoint states from rank 0 to the rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have the FSDP model with the loaded checkpoint.

    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    """
    ``ShardedStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedStateDictConfig` and it is used by FSDP to determine the
        type of state dict values. Users should not manually modify
        ``_use_dtensor``.
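
    A minimal sketch of saving a sharded state dict with this config, assuming
    ``fsdp_model`` is an already-constructed FSDP-wrapped model::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType
        >>> cfg = ShardedStateDictConfig(offload_to_cpu=True)
        >>> with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT, cfg):
        >>>     sharded_sd = fsdp_model.state_dict()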
    """

    _use_dtensor: bool = False


@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all ``optim_state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    corresponding ``optim_state_dict`` type supported by FSDP.

    Attributes:
        offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's
            tensor values to CPU, and if ``False``, then FSDP keeps them on the
            original device (which is GPU unless parameter CPU offloading is
            enabled). (Default: ``True``)
    """

    offload_to_cpu: bool = True


@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
    """
    Attributes:
        rank0_only (bool): If ``True``, then only rank 0 saves the full state
            dict, and nonzero ranks save an empty dict. If ``False``, then all
            ranks save the full state dict. (Default: ``False``)
    """

    rank0_only: bool = False


@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
    offload_to_cpu: bool = False


@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
    """
    ``ShardedOptimStateDictConfig`` is a config class meant to be used with
    ``StateDictType.SHARDED_STATE_DICT``.

    Attributes:
        _use_dtensor (bool): If ``True``, then FSDP saves the state dict values
            as ``DTensor``, and if ``False``, then FSDP saves them as
            ``ShardedTensor``. (Default: ``False``)

    .. warning:: ``_use_dtensor`` is a private field of
        :class:`ShardedOptimStateDictConfig` and it is used by FSDP to
        determine the type of state dict values. Users should not manually
        modify ``_use_dtensor``.
    """

    _use_dtensor: bool = False


@dataclass
class StateDictSettings:
    """
    ``StateDictSettings`` groups a ``state_dict`` type together with its model
    and optimizer state dict configurations.
    """

    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig