# mypy: ignore-errors

import contextlib
import copy
import functools
import math
import traceback
import warnings
from contextlib import contextmanager
from enum import auto, Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_WRAPPED_MODULE,
    ActivationWrapper,
)
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_param_to_fqns,
    FSDP_PREFIX,
    FSDP_WRAPPED_MODULE,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _check_orig_params_flattened,
    _init_buffer_state,
    _init_core_state,
    _init_device_handle,
    _init_extension,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
    ProcessGroupType,
)
from torch.distributed.fsdp._runtime_utils import (
    _get_fsdp_root_states,
    _is_fsdp_root,
    _lazy_init,
    _post_forward,
    _post_forward_reshard,
    _pre_forward,
    _pre_forward_unshard,
    _root_pre_forward,
    _unshard,
    _wait_for_computation_stream,
)
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    MixedPrecision,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictSettings,
    StateDictType,
)
from torch.distributed.tensor import DeviceMesh
from torch.distributed.utils import _p_assert

from ._flat_param import FlatParameter, FlatParamHandle
from ._optim_utils import (
    _flatten_optim_state_dict,
    _get_param_id_to_param_from_optim_input,
    _get_param_key_to_param,
    _get_param_to_param_id_from_optim_input,
    _get_param_to_param_key,
    _optim_state_dict,
    _rekey_sharded_optim_state_dict,
    _set_optim_use_dtensor,
)
from ._state_dict_utils import _register_all_state_dict_hooks
from ._unshard_param_utils import (
    _deregister_orig_params,
    _register_flat_param,
    _register_orig_params,
    _unshard_params,
    _unshard_params_for_summon,
)
from .wrap import CustomPolicy, ModuleWrapPolicy


__all__ = [
    "FullyShardedDataParallel",
    "OptimStateKeyType",
]


FLAT_PARAM = "_flat_param"


class OptimStateKeyType(Enum):
    """Represents the type of key in an optimizer state-dict."""

    PARAM_NAME = auto()
    PARAM_ID = auto()


class FullyShardedDataParallel(nn.Module, _FSDPState):
    """A wrapper for sharding module parameters across data parallel workers.

    This is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
    FullyShardedDataParallel is commonly shortened to FSDP.

    .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
    .. _DeepSpeed: https://www.deepspeed.ai/

    To understand FSDP internals, refer to the
    :ref:`fsdp_notes`.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> torch.cuda.set_device(device_id)
        >>> sharded_module = FSDP(my_module)
        >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        >>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
        >>> loss = x.sum()
        >>> loss.backward()
        >>> optim.step()

    Using FSDP involves wrapping your module and then initializing your
    optimizer afterward. This ordering is required since FSDP changes the
    parameter variables.

    When setting up FSDP, you need to consider the destination CUDA
    device. If the device has an ID (``dev_id``), you have three options:

    1. Place the module on that device
    2. Set the device using ``torch.cuda.set_device(dev_id)``
    3. Pass ``dev_id`` into the ``device_id`` constructor argument.

    This ensures that the FSDP instance's compute device is the
    destination device. For options 1 and 3, the FSDP initialization
    always occurs on GPU. For option 2, the FSDP initialization
    happens on the module's current device, which may be a CPU.

    If you're using the ``sync_module_states=True`` flag, you need to
    ensure that the module is on a GPU or use the ``device_id``
    argument to specify a CUDA device that FSDP will move the module
    to in the FSDP constructor. This is necessary because
    ``sync_module_states=True`` requires GPU communication.

    FSDP also takes care of moving input tensors given to the forward
    method to the GPU compute device, so you don't need to manually move
    them from CPU.

    For ``use_orig_params=True``,
    ``ShardingStrategy.SHARD_GRAD_OP`` exposes the unsharded
    parameters, not the sharded parameters, after forward, unlike
    ``ShardingStrategy.FULL_SHARD``. If you want
    to inspect the gradients, you can use the ``summon_full_params``
    method with ``with_grads=True``.

    With ``limit_all_gathers=True``, you may see a gap in the FSDP
    pre-forward where the CPU thread is not issuing any kernels. This is
    intentional and shows the rate limiter in effect. Synchronizing the CPU
    thread in that way prevents over-allocating memory for subsequent
    all-gathers, and it should not actually delay GPU kernel execution.

    FSDP replaces managed modules' parameters with ``torch.Tensor``
    views during forward and backward computation for autograd-related
    reasons. If your module's forward relies on saved references to
    the parameters instead of reacquiring the references each
    iteration, then it will not see FSDP's newly created views,
    and autograd will not work correctly.

    Finally, when using ``sharding_strategy=ShardingStrategy.HYBRID_SHARD``
    with the sharding process group being intra-node and the
    replication process group being inter-node, setting
    ``NCCL_CROSS_NIC=1`` can help improve the all-reduce times over
    the replication process group for some cluster setups.

    **Limitations**

    There are several limitations to be aware of when using FSDP:

    * FSDP currently does not support gradient accumulation outside
      ``no_sync()`` when using CPU offloading. This is because FSDP
      uses the newly-reduced gradient instead of accumulating with any
      existing gradient, which can lead to incorrect results.

    * FSDP does not support running the forward pass of a submodule
      that is contained in an FSDP instance.
      This is because the
      submodule's parameters will be sharded, but the submodule itself
      is not an FSDP instance, so its forward pass will not all-gather
      the full parameters appropriately.

    * FSDP does not work with double backwards due to the way it
      registers backward hooks.

    * FSDP has some constraints when freezing parameters.
      For ``use_orig_params=False``, each FSDP instance must manage
      parameters that are all frozen or all non-frozen. For
      ``use_orig_params=True``, FSDP supports mixing frozen and
      non-frozen parameters, but it's recommended to avoid doing so to
      prevent higher than expected gradient memory usage.

    * As of PyTorch 1.12, FSDP offers limited support for shared
      parameters. If enhanced shared parameter support is needed for
      your use case, please post in
      `this issue <https://github.com/pytorch/pytorch/issues/77724>`__.

    * You should avoid modifying the parameters between forward and
      backward without using the ``summon_full_params`` context, as
      the modifications may not persist.

    Args:
        module (nn.Module):
            This is the module to be wrapped with FSDP.
        process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]):
            This is the process group over which the model is sharded and thus
            the one used for FSDP's all-gather and reduce-scatter collective
            communications. If ``None``, then FSDP uses the default process
            group. For hybrid sharding strategies such as
            ``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of
            process groups, representing the groups over which to shard and
            replicate, respectively. If ``None``, then FSDP constructs process
            groups for the user to shard intra-node and replicate inter-node.
            (Default: ``None``)
        sharding_strategy (Optional[ShardingStrategy]):
            This configures the sharding strategy, which may trade off memory
            saving and communication overhead. See :class:`ShardingStrategy`
            for details. (Default: ``FULL_SHARD``)
        cpu_offload (Optional[CPUOffload]):
            This configures CPU offloading. If this is set to ``None``, then
            no CPU offloading happens. See :class:`CPUOffload` for details.
            (Default: ``None``)
        auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]):
            This specifies a policy to apply FSDP to submodules of ``module``,
            which is needed for communication and computation overlap and thus
            affects performance. If ``None``, then FSDP only applies to
            ``module``, and users should manually apply FSDP to parent modules
            themselves (proceeding bottom-up). For convenience, this accepts
            ``ModuleWrapPolicy`` directly, which allows users to specify the
            module classes to wrap (e.g. the transformer block). Otherwise,
            this should be a callable that takes in three arguments
            ``module: nn.Module``, ``recurse: bool``, and
            ``nonwrapped_numel: int`` and should return a ``bool`` specifying
            whether the passed-in ``module`` should have FSDP applied if
            ``recurse=False`` or if the traversal should continue into the
            module's subtree if ``recurse=True``. Users may add additional
            arguments to the callable. The ``size_based_auto_wrap_policy`` in
            ``torch.distributed.fsdp.wrap.py`` gives an example callable that
            applies FSDP to a module if the parameters in its subtree exceed
            100M numel. We recommend printing the model after applying FSDP
            and adjusting as needed.
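
            As a sketch, wrapping at module-class granularity (here
            ``TransformerBlock`` stands in for a user-defined block class)
            might look like::

                >>> # xdoctest: +SKIP("undefined variables")
                >>> from torch.distributed.fsdp.wrap import ModuleWrapPolicy
                >>> my_auto_wrap_policy = ModuleWrapPolicy({TransformerBlock})
                >>> fsdp_model = FSDP(my_module, auto_wrap_policy=my_auto_wrap_policy)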

            Example::

                >>> def custom_auto_wrap_policy(
                >>>     module: nn.Module,
                >>>     recurse: bool,
                >>>     nonwrapped_numel: int,
                >>>     # Additional custom arguments
                >>>     min_num_params: int = int(1e8),
                >>> ) -> bool:
                >>>     return nonwrapped_numel >= min_num_params
                >>> # Configure a custom `min_num_params`
                >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))

        backward_prefetch (Optional[BackwardPrefetch]):
            This configures explicit backward prefetching of all-gathers. If
            ``None``, then FSDP does not backward prefetch, and there is no
            communication and computation overlap in the backward pass. See
            :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``)
        mixed_precision (Optional[MixedPrecision]):
            This configures native mixed precision for FSDP. If this is set to
            ``None``, then no mixed precision is used. Otherwise, parameter,
            buffer, and gradient reduction dtypes can be set. See
            :class:`MixedPrecision` for details. (Default: ``None``)
        ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
            own parameters and child modules' parameters and buffers are
            ignored by this instance. None of the modules directly in
            ``ignored_modules`` should be :class:`FullyShardedDataParallel`
            instances, and any child modules that are already-constructed
            :class:`FullyShardedDataParallel` instances will not be ignored if
            they are nested under this instance. This argument may be used to
            avoid sharding specific parameters at module granularity when using an
            ``auto_wrap_policy`` or if parameters' sharding is not managed by
            FSDP. (Default: ``None``)
        param_init_fn (Optional[Callable[[nn.Module], None]]):
            A ``Callable[torch.nn.Module] -> None`` that
            specifies how modules that are currently on the meta device should
            be initialized onto an actual device. As of v1.12, FSDP detects
            modules with parameters or buffers on meta device via ``is_meta``
            and either applies ``param_init_fn`` if specified or calls
            ``nn.Module.reset_parameters()`` otherwise. For both cases, the
            implementation should *only* initialize the parameters/buffers of
            the module, not those of its submodules. This is to avoid
            re-initialization. In addition, FSDP also supports deferred
            initialization via torchdistX's (https://github.com/pytorch/torchdistX)
            ``deferred_init()`` API, where the deferred modules are initialized
            by calling ``param_init_fn`` if specified or torchdistX's default
            ``materialize_module()`` otherwise. If ``param_init_fn`` is
            specified, then it is applied to all meta-device modules, meaning
            that it should probably case on the module type. FSDP calls the
            initialization function before parameter flattening and sharding.

            Example::

                >>> # xdoctest: +SKIP("undefined variables")
                >>> module = MyModule(device="meta")
                >>> def my_init_fn(module: nn.Module):
                >>>     # E.g. initialize depending on the module type
                >>>     ...
                >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
                >>> print(next(fsdp_model.parameters()).device) # current CUDA device
                >>> # With torchdistX
                >>> module = deferred_init.deferred_init(MyModule, device="cuda")
                >>> # Will initialize via deferred_init.materialize_module().
                >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)

        device_id (Optional[Union[int, torch.device]]): An ``int`` or
            ``torch.device`` giving the CUDA device on which FSDP
            initialization takes place, including the module initialization
            if needed and the parameter sharding. This should be specified to
            improve initialization speed if ``module`` is on CPU. If the
            default CUDA device was set (e.g. via ``torch.cuda.set_device``),
            then the user may pass ``torch.cuda.current_device`` to this.
            (Default: ``None``)
        sync_module_states (bool): If ``True``, then each FSDP module will
            broadcast module parameters and buffers from rank 0 to ensure that
            they are replicated across ranks (adding communication overhead to
            this constructor). This can help load ``state_dict`` checkpoints
            via ``load_state_dict`` in a memory efficient way. See
            :class:`FullStateDictConfig` for an example of this. (Default:
            ``False``)
        forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches
            the next forward-pass all-gather before the current forward
            computation. This is only useful for CPU-bound workloads, in which
            case issuing the next all-gather earlier may improve overlap. This
            should only be used for static-graph models since the prefetching
            follows the first iteration's execution order. (Default: ``False``)
        limit_all_gathers (bool): If ``True``, then FSDP explicitly
            synchronizes the CPU thread to ensure GPU memory usage from only
            *two* consecutive FSDP instances (the current instance running
            computation and the next instance whose all-gather is prefetched).
            If ``False``, then FSDP allows the CPU thread to issue all-gathers
            without any extra synchronization. (Default: ``True``) We often
            refer to this feature as the "rate limiter". This flag should only
            be set to ``False`` for specific CPU-bound workloads with low
            memory pressure in which case the CPU thread can aggressively issue
            all kernels without concern for the GPU memory usage.
        use_orig_params (bool): Setting this to ``True`` has FSDP use
            ``module`` 's original parameters. FSDP exposes those original
            parameters to the user via :meth:`nn.Module.named_parameters`
            instead of FSDP's internal :class:`FlatParameter` s. This means
            that the optimizer step runs on the original parameters, enabling
            per-original-parameter hyperparameters. FSDP preserves the original
            parameter variables and manipulates their data between unsharded
            and sharded forms, where they are always views into the underlying
            unsharded or sharded :class:`FlatParameter`, respectively. With the
            current algorithm, the sharded form is always 1D, losing the
            original tensor structure. An original parameter may have all,
            some, or none of its data present for a given rank. In the none
            case, its data will be like a size-0 empty tensor. Users should not
            author programs relying on what data is present for a given
            original parameter in its sharded form. ``True`` is required to
            use ``torch.compile()``. Setting this to ``False`` exposes FSDP's
            internal :class:`FlatParameter` s to the user via
            :meth:`nn.Module.named_parameters`.
            (Default: ``False``)
        ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]):
            Ignored parameters or modules that will not be managed by this FSDP
            instance, meaning that the parameters are not sharded and their
            gradients are not reduced across ranks. This argument unifies with
            the existing ``ignored_modules`` argument, and we may deprecate
            ``ignored_modules`` soon. For backward compatibility, we keep both
            ``ignored_states`` and ``ignored_modules``, but FSDP only allows one
            of them to be specified as not ``None``.
        device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an alternative to
            process_group. When device_mesh is passed, FSDP will use the underlying process
            groups for all-gather and reduce-scatter collective communications. These two
            args are therefore mutually exclusive. For hybrid sharding strategies such as
            ``ShardingStrategy.HYBRID_SHARD``, users can pass in a 2D DeviceMesh instead
            of a tuple of process groups. For 2D FSDP + TP, users are required to pass in
            device_mesh instead of process_group. For more DeviceMesh info, please visit:
            https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: ProcessGroupType = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        cpu_offload: Optional[CPUOffload] = None,
        auto_wrap_policy: Optional[
            Union[Callable, ModuleWrapPolicy, CustomPolicy]
        ] = None,
        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
        mixed_precision: Optional[MixedPrecision] = None,
        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
        param_init_fn: Optional[Callable[[nn.Module], None]] = None,
        device_id: Optional[Union[int, torch.device]] = None,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = True,
        use_orig_params: bool = False,
        ignored_states: Union[
            Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]
        ] = None,
        device_mesh: Optional[DeviceMesh] = None,
    ):
        torch._C._log_api_usage_once("torch.distributed.fsdp")
        super().__init__()
        if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
            warnings.warn(
                "FSDP will not all-gather parameters for containers that do "
                f"not implement forward: {module}",
                stacklevel=2,
            )
        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
        _init_device_handle(self, module, self._ignored_params, device_id)

        # Add module annotations for Dynamo support (see function for details)
        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)

        # Initializes self.process_group, along with rank and world size. This will
        # also set another attribute, _inter_node_pg, to control the process group
        # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
        # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
        # the same process group state as the root FSDP module.
        self._device_mesh = device_mesh
        _init_process_group_state(
            self,
            process_group,
            sharding_strategy,
            auto_wrap_policy,
            device_mesh,
        )
        if auto_wrap_policy is not None:
            root_kwargs = {
                "process_group": process_group,
                "sharding_strategy": sharding_strategy,
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "mixed_precision": mixed_precision,
                "param_init_fn": param_init_fn,
                "device_id": device_id,
                "sync_module_states": sync_module_states,
                "forward_prefetch": forward_prefetch,
                "limit_all_gathers": limit_all_gathers,
                "use_orig_params": use_orig_params,
                "ignored_states": self._ignored_params,
                "device_mesh": device_mesh,
            }
            if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None:
                # Share root process groups with children to maintain
                # the invariant that all FSDP modules will have the same
                # process groups.
                root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)

            _auto_wrap(
                module,
                auto_wrap_policy,
                self._ignored_modules,
                self._ignored_params,
                root_kwargs,
                FullyShardedDataParallel,
            )

        backward_prefetch_limit = 1
        forward_prefetch_limit = 1
        _init_core_state(
            self,
            sharding_strategy,
            mixed_precision,
            cpu_offload,
            limit_all_gathers,
            use_orig_params,
            backward_prefetch_limit,
            forward_prefetch_limit,
        )
        _init_runtime_state(self)
        _init_prefetching_state(self, backward_prefetch, forward_prefetch)
        _init_buffer_state(self, module)
        # extension needs to be set before `_init_param_handle_from_module()`
        _init_extension(self, device_mesh)
        _init_param_handle_from_module(
            self,
            module,
            device_id,
            param_init_fn,
            sync_module_states,
        )
        self._fsdp_wrapped_module = module
        if not use_orig_params:
            _check_orig_params_flattened(self, self._ignored_params)
            _register_flat_param(self, self)

        # `_state_dict_type` controls the `state_dict()` behavior, which is
        # implemented using post-save and pre-load hooks
        _init_state_dict_state(self)
        _register_all_state_dict_hooks(self)
        self._zero_scalar = None

    @property
    def module(self) -> nn.Module:
        """Return the wrapped module."""
        # FSDP's `.module` must refer to the innermost wrapped module when
        # composing with other module wrappers in order for state dict to work
        if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
            return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
        return self._fsdp_wrapped_module

    @property
    def _has_params(self) -> bool:
        """Returns whether this FSDP instance manages any parameters."""
        return hasattr(self, "_handle") and self._handle is not None

    @property
    def _flat_param(self) -> Optional[FlatParameter]:
        return self._handle.flat_param if self._handle else None

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self._fsdp_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is an ``nn.Sequential``."""
        if hasattr(self, FSDP_WRAPPED_MODULE):
            return self._fsdp_wrapped_module.__getitem__(key)  # type: ignore[operator]
        return super().__getitem__(key)

    def check_is_root(self) -> bool:
        """Check if this instance is a root FSDP module."""
        return _is_fsdp_root(self, self)

    @staticmethod
    def fsdp_modules(
        module: nn.Module,
        root_only: bool = False,
    ) -> List["FullyShardedDataParallel"]:
        """Return all nested FSDP instances.

        This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``.

        Args:
            module (torch.nn.Module): Root module, which may or may not be an
                ``FSDP`` module.
            root_only (bool): Whether to return only FSDP root modules.
                (Default: ``False``)

        Returns:
            List[FullyShardedDataParallel]: FSDP modules that are nested in
            the input ``module``.
        """
        if root_only:
            return _get_fsdp_root_states(module)
        return traversal_utils._get_fsdp_states(module)

    def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.

        Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).

        Compared to ``torch.nn.Module.apply``, this version additionally gathers
        the full parameters before applying ``fn``. It should not be called from
        within another ``summon_full_params`` context.

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self
        """
        uninitialized = self._is_root is None
        self._assert_state(TrainingState.IDLE)
        # Use `_unshard_params_for_summon()` with `recurse=False` instead of
        # `_unshard_fsdp_state_params()` directly to perform lazy
        # initialization, which is needed to initialize `FlatParameter`
        # parameter attributes as required by the unshard logic
        with _unshard_params_for_summon(
            self,
            self,
            writeback=True,
            rank0_only=False,
            offload_to_cpu=False,
            with_grads=False,
        ):
            ret = super().apply(fn)

        # Reset lazy init called in `_unshard_params_for_summon()` since
        # `apply()` may have been called on FSDP instance that is not truly a
        # root, in which case it will be incorrectly marked as one.
        if uninitialized and self._is_root:
            for module in traversal_utils._get_fsdp_states(self):
                module._reset_lazy_init()

        return ret

    def _mixed_precision_enabled_for_buffers(self) -> bool:
        """Return whether the user explicitly enabled buffer mixed precision.

        NOTE: Unlike parameters and gradient reduction, buffer mixed precision
        is applied at the FSDP instance level, not the ``FlatParameter`` level,
        which may be different for the composable code path.
        """
        return self.mixed_precision.buffer_dtype is not None

    def _low_precision_hook_enabled(self) -> bool:
        """Whether a low precision hook is registered or not."""
        return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS

    def _reset_lazy_init(self) -> None:
        """Reset instance so :func:`_lazy_init` will run on the next forward."""
        self._is_root: Optional[bool] = None

    @staticmethod
    def set_state_dict_type(
        module: nn.Module,
        state_dict_type: StateDictType,
        state_dict_config: Optional[StateDictConfig] = None,
        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
    ) -> StateDictSettings:
        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.

        Also takes (optional) configuration for the model's and optimizer's state dict.
        The target module does not have to be an FSDP module. If the target
        module is an FSDP module, its ``state_dict_type`` will also be changed.

        .. note:: This API should be called for only the top-level (root)
            module.

        .. note:: This API enables users to transparently use the conventional
            ``state_dict`` API to take model checkpoints in cases where the
            root FSDP module is wrapped by another ``nn.Module``. For example,
            the following will ensure ``state_dict`` is called on all non-FSDP
            instances, while dispatching into the `sharded_state_dict`
            implementation for FSDP:

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = DDP(FSDP(...))
            >>> FSDP.set_state_dict_type(
            >>>     model,
            >>>     StateDictType.SHARDED_STATE_DICT,
            >>>     state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
            >>>     optim_state_dict_config=OptimStateDictConfig(offload_to_cpu=True),
            >>> )
            >>> param_state_dict = model.state_dict()
            >>> optim_state_dict = FSDP.optim_state_dict(model, optim)

        Args:
            module (torch.nn.Module): Root module.
            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
            state_dict_config (Optional[StateDictConfig]): the configuration for the
                target ``state_dict_type``.
            optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration
                for the optimizer state dict.

        Returns:
            A StateDictSettings that includes the previous state_dict type and
            configuration for the module.
        """
        warnings.warn(
            "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being "
            "deprecated. Please use APIs, get_state_dict() and set_state_dict(), "
            "which can support different parallelisms, FSDP1, FSDP2, DDP. "
            "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html"
            "#torch.distributed.checkpoint.state_dict.get_state_dict ."
            "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .",
            FutureWarning,
        )
        _state_dict_type_to_config = {
            StateDictType.FULL_STATE_DICT: FullStateDictConfig,
            StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
            StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
        }
        _optim_state_dict_type_to_config = {
            StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig,
            StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig,
            StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig,
        }

        # Use the default config if a state_dict config is not set.
        state_dict_config_type = _state_dict_type_to_config[state_dict_type]
        optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type]
        if state_dict_config is None:
            state_dict_config = state_dict_config_type()
        if optim_state_dict_config is None:
            optim_state_dict_config = optim_state_dict_config_type()
        if state_dict_config_type != type(state_dict_config):
            raise RuntimeError(
                f"Expected state_dict_config of type {state_dict_config_type} "
                f"but got {type(state_dict_config)}"
            )
        if optim_state_dict_config_type != type(optim_state_dict_config):
            raise RuntimeError(
                f"Expected optim_state_dict_config of type {optim_state_dict_config_type} "
                f"but got {type(optim_state_dict_config)}"
            )

        # Set the state_dict type and configurations.
        prev_state_dict_type = None
        prev_state_dict_config = None
        prev_optim_state_dict_config = None
        for submodule in traversal_utils._get_fsdp_states(module):
            if prev_state_dict_type is None:
                prev_state_dict_type = submodule._state_dict_type
            else:
                assert (
                    prev_state_dict_type == submodule._state_dict_type
                ), "All FSDP modules should have the same state_dict_type."
            if prev_state_dict_config is None:
                prev_state_dict_config = submodule._state_dict_config
            else:
                assert isinstance(
                    submodule._state_dict_config, type(prev_state_dict_config)
                ), "All FSDP modules must have the same type of state_dict_config."
            if prev_optim_state_dict_config is None:
                prev_optim_state_dict_config = submodule._optim_state_dict_config
            else:
                assert isinstance(
                    submodule._optim_state_dict_config,
                    type(prev_optim_state_dict_config),
                ), "All FSDP modules must have the same type of optim_state_dict_config."

            submodule._state_dict_type = state_dict_type
            submodule._state_dict_config = state_dict_config
            submodule._optim_state_dict_config = optim_state_dict_config

        return StateDictSettings(
            prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config
        )

    @staticmethod
    def get_state_dict_type(module: nn.Module) -> StateDictSettings:
        """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``.

        The target module does not have to be an FSDP module.

        Returns:
            A ``StateDictSettings`` containing the state_dict_type and
            state_dict / optim_state_dict configs that are currently set.

        Raises:
            ``AssertionError`` if the ``StateDictSettings`` for different
            FSDP submodules differ.
        """
        state_dict_settings: Optional[StateDictSettings] = None
        for submodule in FullyShardedDataParallel.fsdp_modules(module):
            if state_dict_settings is None:
                state_dict_settings = StateDictSettings(
                    state_dict_type=submodule._state_dict_type,
                    state_dict_config=submodule._state_dict_config,
                    optim_state_dict_config=submodule._optim_state_dict_config,
                )
                _set_optim_use_dtensor(submodule, state_dict_settings)
            else:
                submodule_settings = StateDictSettings(
                    submodule._state_dict_type,
                    submodule._state_dict_config,
                    submodule._optim_state_dict_config,
                )
                assert state_dict_settings == submodule_settings, (
                    "All FSDP modules must have the same state dict settings. "
                    f"Got {submodule_settings} and {state_dict_settings}."
                )
                _set_optim_use_dtensor(submodule, submodule_settings)
        return state_dict_settings

    @staticmethod
    @contextlib.contextmanager
    def state_dict_type(
        module: nn.Module,
        state_dict_type: StateDictType,
        state_dict_config: Optional[StateDictConfig] = None,
        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
    ) -> Generator:
        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.

        This context manager has the same function as :meth:`set_state_dict_type`. See the
        documentation of :meth:`set_state_dict_type` for details.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = DDP(FSDP(...))
            >>> with FSDP.state_dict_type(
            >>>     model,
            >>>     StateDictType.SHARDED_STATE_DICT,
            >>> ):
            >>>     checkpoint = model.state_dict()

        Args:
            module (torch.nn.Module): Root module.
            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
            state_dict_config (Optional[StateDictConfig]): the model ``state_dict``
                configuration for the target ``state_dict_type``.
            optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer
                ``state_dict`` configuration for the target ``state_dict_type``.
        """
        prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
            module,
            state_dict_type,
            state_dict_config,
            optim_state_dict_config,
        )
        yield
        FullyShardedDataParallel.set_state_dict_type(
            module,
            prev_state_dict_settings.state_dict_type,
            prev_state_dict_settings.state_dict_config,
            prev_state_dict_settings.optim_state_dict_config,
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
        handle = self._handle
        with torch.autograd.profiler.record_function(
            "FullyShardedDataParallel.forward"
        ):
            args, kwargs = _root_pre_forward(self, self, args, kwargs)
            unused = None
            args, kwargs = _pre_forward(
                self,
                handle,
                _pre_forward_unshard,
                self._fsdp_wrapped_module,
                args,
                kwargs,
            )
            if handle:
                _p_assert(
                    handle.flat_param.device == self.compute_device,
                    "Expected `FlatParameter` to be on the compute device "
                    f"{self.compute_device} but got {handle.flat_param.device}",
                )
            output = self._fsdp_wrapped_module(*args, **kwargs)
            return _post_forward(
                self, handle, _post_forward_reshard, self, unused, output
            )

    @staticmethod
    @contextlib.contextmanager
    def summon_full_params(
        module: nn.Module,
        recurse: bool = True,
        writeback: bool = True,
        rank0_only: bool = False,
        offload_to_cpu: bool = False,
        with_grads: bool = False,
    ) -> Generator:
        r"""Expose full params for FSDP instances with this context manager.

        Can be useful *after* forward/backward for a model to get
        the params for additional processing or checking. It can take a non-FSDP
        module and will summon full params for all contained FSDP modules as
        well as their children, depending on the ``recurse`` argument.

        .. note:: This can be used on inner FSDPs.
        .. note:: This can *not* be used within a forward or backward pass. Nor
            can forward and backward be started from within this context.
        .. note:: Parameters will revert to their local shards after the context
            manager exits; storage behavior is the same as forward.
        .. note:: The full parameters can be modified, but only the portion
            corresponding to the local param shard will persist after the
            context manager exits (unless ``writeback=False``, in which case
            changes will be discarded). In the case where FSDP does not shard
            the parameters, currently only when ``world_size == 1`` or with the
            ``NO_SHARD`` config, the modification is persisted regardless of
            ``writeback``.
        .. note:: This method works on modules which are not FSDP themselves but
            may contain multiple independent FSDP units. In that case, the given
            arguments will apply to all contained FSDP units.

        .. warning:: Note that ``rank0_only=True`` in conjunction with
            ``writeback=True`` is not currently supported and will raise an
            error.
            This is because model parameter shapes would be different
            across ranks within the context, and writing to them can lead to
            inconsistency across ranks when the context is exited.

        .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will
            result in full parameters being redundantly copied to CPU memory for
            GPUs that reside on the same machine, which may incur the risk of
            CPU OOM. It is recommended to use ``offload_to_cpu`` with
            ``rank0_only=True``.

        Args:
            recurse (bool, Optional): recursively summon all params for nested
                FSDP instances (default: True).
            writeback (bool, Optional): if ``False``, modifications to params are
                discarded after the context manager exits;
                disabling this can be slightly more efficient (default: True)
            rank0_only (bool, Optional): if ``True``, full parameters are
                materialized on only global rank 0. This means that within the
                context, only rank 0 will have full parameters and the other
                ranks will have sharded parameters. Note that setting
                ``rank0_only=True`` with ``writeback=True`` is not supported,
                as model parameter shapes will be different across ranks
                within the context, and writing to them can lead to
                inconsistency across ranks when the context is exited.
            offload_to_cpu (bool, Optional): If ``True``, full parameters are
                offloaded to CPU. Note that this offloading currently only
                occurs if the parameter is sharded (which is only not the case
                for world_size = 1 or ``NO_SHARD`` config). It is recommended
                to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid
                redundant copies of model parameters being offloaded to the same CPU memory.
            with_grads (bool, Optional): If ``True``, gradients are also
                unsharded with the parameters. Currently, this is only
                supported when passing ``use_orig_params=True`` to the FSDP
                constructor and ``offload_to_cpu=False`` to this method.
                (Default: ``False``)
        """
        with _unshard_params(
            module, recurse, writeback, rank0_only, offload_to_cpu, with_grads
        ):
            yield

    @contextlib.contextmanager
    def _deregister_orig_params_ctx(self):
        """Deregister the original parameters and expose the :class:`FlatParameter`.

        If a :class:`FlatParameter` is sharded, then
        this refreshes the sharded views before exiting. This method should
        only be called when using the original parameters.
        """
        _p_assert(
            self._use_orig_params,
            "`_deregister_orig_params_ctx()` should only be called when "
            "`_use_orig_params=True`",
        )
        for fsdp_module in traversal_utils._get_fsdp_states(self):
            _deregister_orig_params(fsdp_module, fsdp_module)
        try:
            yield
        finally:
            for fsdp_module in traversal_utils._get_fsdp_states(self):
                _register_orig_params(fsdp_module, fsdp_module)

    def _apply(self, *args, **kwargs):
        """Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``."""
        # When using the original parameters: Since (1) the `FlatParameter`s
        # own the storage and (2) `_apply()` is the subroutine underlying the
        # most common storage-changing ops like `to()` and `cuda()`, we
        # override `_apply()` to have the storage change directly performed on
        # the `FlatParameter`s instead of applying to the original parameters
        # and then writing back to the `FlatParameter`s.
        context = (
            self._deregister_orig_params_ctx()
            if self._use_orig_params
            else contextlib.nullcontext()
        )
        with context:
            return super()._apply(*args, **kwargs)

    def named_buffers(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        """Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.

        Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix
        when inside the :meth:`summon_full_params` context manager.
        """
        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
        for buffer_name, buffer in super().named_buffers(*args, **kwargs):
            if should_clean_name:
                # Remove any instances of the FSDP-specific prefix; there can
                # be multiple in the case of nested FSDP modules
                buffer_name = buffer_name.replace(FSDP_PREFIX, "")
            yield (buffer_name, buffer)

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.

        Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix
        when inside the :meth:`summon_full_params` context manager.
        """
        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
        for param_name, param in super().named_parameters(*args, **kwargs):
            if should_clean_name:
                # Remove any instances of the FSDP-specific prefix; there can
                # be multiple in the case of nested FSDP modules
                param_name = param_name.replace(FSDP_PREFIX, "")
            yield (param_name, param)

    def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
        """Assert we are in the given state."""
        # Since assert can be turned off and this error checking
        # is really important, we use explicit error checking
        # and raise a ValueError if needed.
        if isinstance(state, TrainingState):
            state = [state]
        if self.training_state not in state:
            msg = (
                f"expected to be in states {state} but current state "
                f"is {self.training_state}"
            )
            # In case we are failing in the context of autograd hook, asserting
            # may not generate useful msg. So, let's print it to be sure.
            if self.rank == 0:
                print(f"Asserting FSDP instance is: {self}")
                print(f"ERROR: {msg}")
                traceback.print_stack()
            raise ValueError(msg)

    @contextmanager
    def no_sync(self) -> Generator:
        """Disable gradient synchronizations across FSDP instances.

        Within this context, gradients will be accumulated in module
        variables, which will later be synchronized in the first
        forward-backward pass after exiting the context. This should only be
        used on the root FSDP instance and will recursively apply to all
        children FSDP instances.

        .. note:: This likely results in higher memory usage because FSDP will
            accumulate the full model gradients (instead of gradient shards)
            until the eventual sync.

        .. note:: When used with CPU offloading, the gradients will not be
            offloaded to CPU when inside the context manager. Instead, they
            will only be offloaded right after the eventual sync.
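
        The following is only a sketch of the typical gradient accumulation
        pattern; ``fsdp_model``, ``optim``, and ``batches`` are assumed to be
        defined by the user::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> with fsdp_model.no_sync():
            >>>     for batch in batches[:-1]:
            >>>         fsdp_model(batch).sum().backward()  # no gradient sync
            >>> fsdp_model(batches[-1]).sum().backward()  # gradients sync here
            >>> optim.step()
            >>> optim.zero_grad()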
        """
        _lazy_init(self, self)
        if not self._is_root:
            raise RuntimeError(
                "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module."
            )
        self._assert_state(TrainingState.IDLE)
        old_flags = []
        for m in self.modules():
            if isinstance(m, FullyShardedDataParallel):
                old_flags.append((m, m._sync_gradients))
                m._sync_gradients = False
        try:
            yield
        finally:
            for m, old_flag in old_flags:
                assert not m._sync_gradients, (
                    "`_sync_gradients` was incorrectly set to "
                    "`True` while in the `no_sync()` context manager"
                )
                m._sync_gradients = old_flag

    @torch.no_grad()
    def clip_grad_norm_(
        self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
    ) -> torch.Tensor:
        """Clip the gradient norm of all parameters.

        The norm is computed over all parameters' gradients as viewed as a single vector, and the
        gradients are modified in-place.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'``
                for infinity norm.

        Returns:
            Total norm of the parameters (viewed as a single vector).

        If every FSDP instance uses ``NO_SHARD``, meaning that no
        gradients are sharded across ranks, then you may directly use
        :func:`torch.nn.utils.clip_grad_norm_`.

        If at least some FSDP instance uses a sharded strategy (i.e.
        one other than ``NO_SHARD``), then you should use this method
        instead of :func:`torch.nn.utils.clip_grad_norm_` since this method
        handles the fact that gradients are sharded across ranks.

        The total norm returned will have the "largest" dtype across
        all parameters/gradients as defined by PyTorch's type promotion
        semantics. For example, if *all* parameters/gradients use a low
        precision dtype, then the returned norm's dtype will be that low
        precision dtype, but if there exists at least one parameter/
        gradient using FP32, then the returned norm's dtype will be FP32.

        .. warning:: This needs to be called on all ranks since it uses
            collective communications.
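
        A minimal usage sketch, assuming ``fsdp_model`` (the root FSDP
        instance), ``optim``, and ``batch`` are defined by the user and that
        this runs on every rank::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> loss = fsdp_model(batch).sum()
            >>> loss.backward()
            >>> fsdp_model.clip_grad_norm_(max_norm=1.0)
            >>> optim.step()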
        """
        _lazy_init(self, self)
        if not self._is_root:
            raise RuntimeError(
                "`clip_grad_norm_()` should only be called on the root FSDP instance"
            )
        if self._zero_scalar is None:
            self._zero_scalar = torch.tensor(0.0, device=self.compute_device)
        self._assert_state(TrainingState.IDLE)
        # If every FSDP instance uses `NO_SHARD`, then we can directly use
        # the normal `nn.utils` one targeting local gradients
        all_no_shard = all(
            not handle.uses_sharded_strategy for handle in self._all_handles
        )
        if all_no_shard:
            return torch.nn.utils.clip_grad_norm_(
                self.parameters(), max_norm, norm_type
            )
        # Otherwise, there exists some FSDP instance using a sharded strategy,
        # where sharded and non-sharded parameters must be handled separately
        max_norm = float(max_norm)
        norm_type = float(norm_type)
        sharded_params_set = set()
        nonsharded_params_set = set()  # `NO_SHARD` or not FSDP-managed
        # Make sure to compute the local norm using lists for deterministic
        # iteration order and hence deterministic total norm computation
        sharded_params = []
        nonsharded_params = []
        grads: List[torch.Tensor] = []
        for handle in self._all_handles:
            if handle.uses_sharded_strategy:
                target_set = sharded_params_set
                target_list = sharded_params
            else:
                target_set = nonsharded_params_set
                target_list = nonsharded_params
            if handle._use_orig_params:
                for param in handle.flat_param._params:
                    if param not in target_set:
                        target_set.add(param)
                        target_list.append(param)
                        if param.grad is not None:
                            grads.append(param.grad)
            else:
                if handle.flat_param not in target_set:
                    target_set.add(handle.flat_param)
                    target_list.append(handle.flat_param)
                    if handle.flat_param.grad is not None:
                        grads.append(handle.flat_param.grad)
        for param in self.parameters():
            not_fsdp_managed = (
                param not in sharded_params_set and param not in nonsharded_params_set
            )
            if not_fsdp_managed:
                nonsharded_params_set.add(param)
                nonsharded_params.append(param)
                if param.grad is not None:
                    grads.append(param.grad)
        # Compute local norms (forced to be in FP32)
        local_sharded_norm = _get_grad_norm(
            sharded_params, norm_type, self._zero_scalar, self.compute_device
        )
        local_nonsharded_norm = (
            _get_grad_norm(
                nonsharded_params, norm_type, self._zero_scalar, self.compute_device
            )
            if nonsharded_params
            else None
        )
        # Reconstruct the total gradient norm depending on the norm type
        if norm_type == math.inf:
            total_norm = (
                torch.maximum(local_sharded_norm, local_nonsharded_norm)
                if local_nonsharded_norm is not None
                else local_sharded_norm
            )
            dist.all_reduce(
                total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
            )
        else:
            total_norm = local_sharded_norm**norm_type
            dist.all_reduce(total_norm, group=self.process_group)
            # All-reducing the local non-sharded norm would count it an extra
            # world-size-many times
            if local_nonsharded_norm is not None:
                total_norm += local_nonsharded_norm**norm_type
            total_norm = total_norm ** (1.0 / norm_type)
        if self.cpu_offload.offload_params:
            total_norm = total_norm.cpu()

        clip_coef = max_norm / (total_norm + 1e-6)
        # Multiplying by the clamped coefficient is meaningless when it is
        # equal to 1, but it avoids the host-device sync
        # that would result from `if clip_coef < 1`
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        for grad in grads:
            grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype))
        # Use the "largest" dtype by type promotion semantics to use the same
        # dtype as if we did not force local norm computation to be in FP32
        if len(grads) == 0:
            # If this rank has no gradients, then we must default to FP32
            # unless we use additional communication, which we prefer to avoid
            # since `clip_grad_norm_()` is called in the training loop
            warnings.warn(
                f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no "
                "gradients -- returning the total norm in the default dtype "
                f"{total_norm.dtype}"
            )  # warn since this is generally unexpected
            return total_norm
        total_norm_dtype = functools.reduce(
            torch.promote_types,
            [grad.dtype for grad in grads],
        )
        return total_norm.to(total_norm_dtype)

    @staticmethod
    def _warn_optim_input(optim_input, *, stacklevel: int = 1):
        if optim_input is not None:
            warnings.warn(
                "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. "
                "You may remove it from your code without changing its functionality.",
                FutureWarning,
                stacklevel=stacklevel + 1,
            )

    @staticmethod
    def _is_using_optim_input(optim_input, optim) -> bool:
        if optim_input is None and optim is None:
            # Use the default behavior of `optim_input`
            return True
        if optim_input is not None:
            # Use the `optim_input` code path
            return True
        # Use the `optim` code path
        return False

    @staticmethod
    def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1):
        warnings.warn(
            f"``FullyShardedDataParallel.{curr}`` is being deprecated and is "
            f"replaced by ``FullyShardedDataParallel.{new}``. "
            f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.",
            FutureWarning,
            stacklevel=stacklevel + 1,
        )

    @staticmethod
    def _optim_state_dict_impl(
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        optim_state_dict: Dict[str, Any],
        optim_input: Optional[
            Union[
                List[Dict[str, Any]],
                Iterable[torch.nn.Parameter],
            ]
        ] = None,
        rank0_only: bool = True,
        full_state_dict: bool = True,
        group: Optional[dist.ProcessGroup] = None,
        cpu_offload: bool = True,
        *,
        _stacklevel: int = 1,
    ) -> Dict[str, Any]:
        """Transform the state-dict of an optimizer corresponding to a sharded model.

        This is the internal API that is used by all the optim_state_dict implementations.
        Given model, optim, the original optim_state_dict, this API removes the
        FSDP internal information and internal sharding from the optim_state_dict.
        """
        if full_state_dict:
            FullyShardedDataParallel._warn_optim_input(
                optim_input, stacklevel=_stacklevel + 1
            )
            using_optim_input = FullyShardedDataParallel._is_using_optim_input(
                optim_input,
                optim,
            )
        else:
            using_optim_input = False
            assert optim_input is None and not rank0_only

        use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
            0
        ]._use_orig_params
        assert all(
            use_orig_params == m._use_orig_params
            for m in FullyShardedDataParallel.fsdp_modules(model)
        ), "Not all FSDP modules have the same _use_orig_params value"

        return _optim_state_dict(
            model=model,
            optim=optim,
            optim_state_dict=optim_state_dict,
            optim_input=optim_input,
            rank0_only=rank0_only,
            shard_state=not full_state_dict,
            group=group,
            using_optim_input=using_optim_input,
            use_orig_params=use_orig_params,
            cpu_offload=cpu_offload,
        )

    @staticmethod
    def _optim_state_dict_to_load_impl(
        optim_state_dict: Dict[str, Any],
        model: torch.nn.Module,
        optim_input: Optional[
            Union[
                List[Dict[str, Any]],
                Iterable[torch.nn.Parameter],
            ]
        ] = None,
        optim: Optional[torch.optim.Optimizer] = None,
        full_state_dict: bool = True,
        rank0_only: bool = False,
        is_named_optimizer: bool = False,
        group: Optional[dist.ProcessGroup] = None,
    ) -> Dict[str, Any]:
        """
        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.

        This is the internal API that is used by all the load optim_state_dict implementations.
        Given model, optim, and the saved optim_state_dict, this API adds the FSDP
        internal information and internal sharding to the optim_state_dict.
        """
        if full_state_dict:
            FullyShardedDataParallel._warn_optim_input(optim_input)
            using_optim_input = FullyShardedDataParallel._is_using_optim_input(
                optim_input,
                optim,
            )
        else:
            using_optim_input = False
            assert optim_input is None and not rank0_only

        use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
            0
        ]._use_orig_params
        assert all(
            use_orig_params == m._use_orig_params
            for m in FullyShardedDataParallel.fsdp_modules(model)
        ), "Not all FSDP modules have the same _use_orig_params value"

        if rank0_only and dist.get_rank(group) > 0:
            optim_state_dict = {}
        sharded_osd = _flatten_optim_state_dict(
            optim_state_dict,
            model=model,
            use_orig_params=use_orig_params,
            optim=(optim if is_named_optimizer else None),
            rank0_only=rank0_only,
            group=group,
        )
        return _rekey_sharded_optim_state_dict(
            sharded_osd,
            model=model,
            optim=optim,
            optim_input=optim_input,
            using_optim_input=using_optim_input,
            is_named_optimizer=is_named_optimizer,
        )

    @staticmethod
    def full_optim_state_dict(
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        optim_input: Optional[
            Union[
                List[Dict[str, Any]],
                Iterable[torch.nn.Parameter],
            ]
        ] = None,
        rank0_only: bool = True,
        group: Optional[dist.ProcessGroup] = None,
    ) -> Dict[str, Any]:
        """Return the full optimizer state-dict.

        Consolidates the full optimizer state on rank 0 and returns it
        as a :class:`dict` following the convention of
        :meth:`torch.optim.Optimizer.state_dict`, i.e.
        with keys ``"state"`` and ``"param_groups"``. The flattened parameters
        in ``FSDP`` modules contained in ``model`` are mapped back to their
        unflattened parameters.

        This needs to be called on all ranks since it uses
        collective communications. However, if ``rank0_only=True``, then
        the state dict is only populated on rank 0, and all other ranks
        return an empty :class:`dict`.

        Unlike ``torch.optim.Optimizer.state_dict()``, this method
        uses full parameter names as keys instead of parameter IDs.

        Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors
        contained in the optimizer state dict are not cloned, so there may
        be aliasing surprises. For best practices, consider saving the
        returned optimizer state dict immediately, e.g. using
        ``torch.save()``.

        Args:
            model (torch.nn.Module): Root module (which may or may not be a
                :class:`FullyShardedDataParallel` instance) whose parameters
                were passed into the optimizer ``optim``.
            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
                parameters.
            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
                Input passed into the optimizer ``optim`` representing either a
                :class:`list` of parameter groups or an iterable of parameters;
                if ``None``, then this method assumes the input was
                ``model.parameters()``. This argument is deprecated, and there
                is no need to pass it in anymore. (Default: ``None``)
            rank0_only (bool): If ``True``, saves the populated :class:`dict`
                only on rank 0; if ``False``, saves it on all ranks. (Default:
                ``True``)
            group (dist.ProcessGroup): Model's process group or ``None`` if using
                the default process group. (Default: ``None``)

        Returns:
            Dict[str, Any]: A :class:`dict` containing the optimizer state for
            ``model`` 's original unflattened parameters and including keys
            "state" and "param_groups" following the convention of
            :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
            then nonzero ranks return an empty :class:`dict`.
        """
        FullyShardedDataParallel._warn_legacy_optim_state_dict(
            "full_optim_state_dict",
            "optim_state_dict",
            stacklevel=2,
        )
        return FullyShardedDataParallel._optim_state_dict_impl(
            model=model,
            optim=optim,
            optim_state_dict=optim.state_dict(),
            optim_input=optim_input,
            rank0_only=rank0_only,
            group=group,
            full_state_dict=True,
            _stacklevel=2,
        )
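
    # NOTE: Illustrative sketch, not part of the API: the typical save path for
    # ``full_optim_state_dict`` consolidates the optimizer state on rank 0 and
    # writes it to disk there. ``PATH`` is a placeholder for the example.
    #
    #     >>> # xdoctest: +SKIP("undefined variables")
    #     >>> model, optim = ...
    #     >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # rank0_only=True
    #     >>> if dist.get_rank() == 0:
    #     >>>     torch.save(full_osd, PATH)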

    @staticmethod
    def sharded_optim_state_dict(
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        group: Optional[dist.ProcessGroup] = None,
    ) -> Dict[str, Any]:
        """Return the optimizer state-dict in its sharded form.

        The API is similar to :meth:`full_optim_state_dict` but this API chunks
        all non-zero-dimension states to :class:`ShardedTensor` to save memory.
        This API should only be used when the model ``state_dict`` is derived
        with the context manager ``with state_dict_type(SHARDED_STATE_DICT):``.

        For the detailed usage, refer to :meth:`full_optim_state_dict`.

        .. warning:: The returned state dict contains ``ShardedTensor`` and
            cannot be directly used by the regular ``optim.load_state_dict``.
        """
        FullyShardedDataParallel._warn_legacy_optim_state_dict(
            "sharded_optim_state_dict",
            "optim_state_dict",
            stacklevel=2,
        )
        return FullyShardedDataParallel._optim_state_dict_impl(
            model=model,
            optim=optim,
            optim_state_dict=optim.state_dict(),
            optim_input=None,
            rank0_only=False,
            full_state_dict=False,
            group=group,
            _stacklevel=2,
        )
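
    # NOTE: Illustrative sketch, not part of the API: a sharded save/load round
    # trip, assuming hypothetical ``save_a_checkpoint``/``load_a_checkpoint``
    # helpers and that the new model/optimizer are rebuilt with the same FSDP
    # wrapping. The sharded optimizer state dict must be flattened via
    # :meth:`flatten_sharded_optim_state_dict` (defined below) before loading.
    #
    #     >>> # xdoctest: +SKIP("undefined variables")
    #     >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    #     >>>     sharded_osd = FSDP.sharded_optim_state_dict(model, optim)
    #     >>>     save_a_checkpoint(model.state_dict(), sharded_osd)
    #     >>> # ...later: rebuild model/optim, then flatten and load
    #     >>> model_sd, sharded_osd = load_a_checkpoint()
    #     >>> flattened_osd = FSDP.flatten_sharded_optim_state_dict(
    #     >>>     sharded_osd, model, optim
    #     >>> )
    #     >>> optim.load_state_dict(flattened_osd)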

    @staticmethod
    def shard_full_optim_state_dict(
        full_optim_state_dict: Dict[str, Any],
        model: torch.nn.Module,
        optim_input: Optional[
            Union[
                List[Dict[str, Any]],
                Iterable[torch.nn.Parameter],
            ]
        ] = None,
        optim: Optional[torch.optim.Optimizer] = None,
    ) -> Dict[str, Any]:
        """Shard a full optimizer state-dict.

        Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened
        parameters and restricts to only this rank's part of the optimizer state.
        The first argument should be the return value of :meth:`full_optim_state_dict`.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
            >>> model, optim = ...
            >>> full_osd = FSDP.full_optim_state_dict(model, optim)
            >>> torch.save(full_osd, PATH)
            >>> # Define new model with possibly different world size
            >>> new_model, new_optim = ...
            >>> full_osd = torch.load(PATH)
            >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
            >>> new_optim.load_state_dict(sharded_osd)

        .. note:: Both :meth:`shard_full_optim_state_dict` and
            :meth:`scatter_full_optim_state_dict` may be used to get the
            sharded optimizer state dict to load. Assuming that the full
            optimizer state dict resides in CPU memory, the former requires
            each rank to have the full dict in CPU memory, where each rank
            individually shards the dict without any communication, while the
            latter requires only rank 0 to have the full dict in CPU memory,
            where rank 0 moves each shard to GPU memory (for NCCL) and
            communicates it to ranks appropriately. Hence, the former has
            higher aggregate CPU memory cost, while the latter has higher
            communication cost.

        Args:
            full_optim_state_dict (Dict[str, Any]): Optimizer state dict
                corresponding to the unflattened parameters and holding the
                full non-sharded optimizer state.
            model (torch.nn.Module): Root module (which may or may not be a
                :class:`FullyShardedDataParallel` instance) whose parameters
                correspond to the optimizer state in ``full_optim_state_dict``.
            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
                Input passed into the optimizer representing either a
                :class:`list` of parameter groups or an iterable of parameters;
                if ``None``, then this method assumes the input was
                ``model.parameters()``. This argument is deprecated, and there
                is no need to pass it in anymore. (Default: ``None``)
            optim (Optional[torch.optim.Optimizer]): Optimizer that will load
                the state dict returned by this method. This is the preferred
                argument to use over ``optim_input``. (Default: ``None``)

        Returns:
            Dict[str, Any]: The full optimizer state dict now remapped to
            flattened parameters instead of unflattened parameters and
            restricted to only include this rank's part of the optimizer state.
        """
        FullyShardedDataParallel._warn_legacy_optim_state_dict(
            "shard_full_optim_state_dict",
            "optim_state_dict_to_load",
            stacklevel=2,
        )
        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
            optim_state_dict=full_optim_state_dict,
            model=model,
            optim_input=optim_input,
            optim=optim,
            full_state_dict=True,
            is_named_optimizer=False,
        )

    @staticmethod
    def flatten_sharded_optim_state_dict(
        sharded_optim_state_dict: Dict[str, Any],
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
    ) -> Dict[str, Any]:
        """Flatten a sharded optimizer state-dict.

        The API is similar to :meth:`shard_full_optim_state_dict`. The only
        difference is that the input ``sharded_optim_state_dict`` should be
        returned from :meth:`sharded_optim_state_dict`. Therefore, there will
        be all-gather calls on each rank to gather ``ShardedTensor`` s.

        Args:
            sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict
                corresponding to the unflattened parameters and holding the
                sharded optimizer state.
            model (torch.nn.Module):
                Refer to :meth:`shard_full_optim_state_dict`.
            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
                parameters.

        Returns:
            Refer to :meth:`shard_full_optim_state_dict`.
        """
        FullyShardedDataParallel._warn_legacy_optim_state_dict(
            "flatten_sharded_optim_state_dict",
            "optim_state_dict_to_load",
            stacklevel=2,
        )
        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
            optim_state_dict=sharded_optim_state_dict,
            model=model,
            optim_input=None,
            optim=optim,
            full_state_dict=False,
            is_named_optimizer=False,
        )

    @staticmethod
    def scatter_full_optim_state_dict(
        full_optim_state_dict: Optional[Dict[str, Any]],
        model: torch.nn.Module,
        optim_input: Optional[
            Union[
                List[Dict[str, Any]],
                Iterable[torch.nn.Parameter],
            ]
        ] = None,
        optim: Optional[torch.optim.Optimizer] = None,
        group: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Scatter the full optimizer state dict from rank 0 to all other ranks.

        Returns the sharded optimizer state dict on each rank.
        The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank
        0, the first argument should be the return value of
        :meth:`full_optim_state_dict`.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
            >>> model, optim = ...
            >>> full_osd = FSDP.full_optim_state_dict(model, optim)  # only non-empty on rank 0
            >>> # Define new model with possibly different world size
            >>> new_model, new_optim, new_group = ...
            >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
            >>> new_optim.load_state_dict(sharded_osd)

        .. note:: Both :meth:`shard_full_optim_state_dict` and
            :meth:`scatter_full_optim_state_dict` may be used to get the
            sharded optimizer state dict to load.
            Assuming that the full
            optimizer state dict resides in CPU memory, the former requires
            each rank to have the full dict in CPU memory, where each rank
            individually shards the dict without any communication, while the
            latter requires only rank 0 to have the full dict in CPU memory,
            where rank 0 moves each shard to GPU memory (for NCCL) and
            communicates it to ranks appropriately. Hence, the former has
            higher aggregate CPU memory cost, while the latter has higher
            communication cost.

        Args:
            full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
                dict corresponding to the unflattened parameters and holding
                the full non-sharded optimizer state if on rank 0; the argument
                is ignored on nonzero ranks.
            model (torch.nn.Module): Root module (which may or may not be a
                :class:`FullyShardedDataParallel` instance) whose parameters
                correspond to the optimizer state in ``full_optim_state_dict``.
            optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
                Input passed into the optimizer representing either a
                :class:`list` of parameter groups or an iterable of parameters;
                if ``None``, then this method assumes the input was
                ``model.parameters()``. This argument is deprecated, and there
                is no need to pass it in anymore. (Default: ``None``)
            optim (Optional[torch.optim.Optimizer]): Optimizer that will load
                the state dict returned by this method. This is the preferred
                argument to use over ``optim_input``. (Default: ``None``)
            group (dist.ProcessGroup): Model's process group or ``None`` if
                using the default process group. (Default: ``None``)

        Returns:
            Dict[str, Any]: The full optimizer state dict now remapped to
            flattened parameters instead of unflattened parameters and
            restricted to only include this rank's part of the optimizer state.
        """
        FullyShardedDataParallel._warn_legacy_optim_state_dict(
            "scatter_full_optim_state_dict",
            "optim_state_dict_to_load",
            stacklevel=2,
        )
        return FullyShardedDataParallel._optim_state_dict_to_load_impl(
            optim_state_dict=full_optim_state_dict,
            model=model,
            optim_input=optim_input,
            optim=optim,
            full_state_dict=True,
            rank0_only=True,
            is_named_optimizer=False,
            group=group,
        )

    @staticmethod
    def rekey_optim_state_dict(
        optim_state_dict: Dict[str, Any],
        optim_state_key_type: OptimStateKeyType,
        model: torch.nn.Module,
        optim_input: Optional[
            Union[
                List[Dict[str, Any]],
                Iterable[torch.nn.Parameter],
            ]
        ] = None,
        optim: Optional[torch.optim.Optimizer] = None,
    ) -> Dict[str, Any]:
        """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``.

        This can be used to achieve compatibility between optimizer state dicts from models with FSDP
        instances and ones without.

        To re-key an FSDP full optimizer state dict (i.e. from
        :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to
        a non-wrapped model::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> wrapped_model, wrapped_optim = ...
            >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
            >>> nonwrapped_model, nonwrapped_optim = ...
            >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
            >>> nonwrapped_optim.load_state_dict(rekeyed_osd)

        To re-key a normal optimizer state dict from a non-wrapped model to be
        loadable to a wrapped model::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> nonwrapped_model, nonwrapped_optim = ...
            >>> osd = nonwrapped_optim.state_dict()
            >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
            >>> wrapped_model, wrapped_optim = ...
            >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
            >>> wrapped_optim.load_state_dict(sharded_osd)

        Returns:
            Dict[str, Any]: The optimizer state dict re-keyed using the
            parameter keys specified by ``optim_state_key_type``.
        """
        FullyShardedDataParallel._warn_optim_input(optim_input)
        using_optim_input = FullyShardedDataParallel._is_using_optim_input(
            optim_input,
            optim,
        )
        assert optim_state_key_type in (
            OptimStateKeyType.PARAM_NAME,
            OptimStateKeyType.PARAM_ID,
        )
        osd = optim_state_dict  # alias
        # Validate that the existing parameter keys are uniformly typed
        uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
        uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
        if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
            any(uses_param_id_mask) and not all(uses_param_id_mask)
        ):
            error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
            raise ValueError(error_msg)
        # Return directly if the existing key type matches the target key type
        if (
            optim_state_key_type == OptimStateKeyType.PARAM_NAME
            and all(uses_param_name_mask)
        ) or (
            optim_state_key_type == OptimStateKeyType.PARAM_ID
            and all(uses_param_id_mask)
        ):
            return osd
        # Otherwise, actually perform the re-keying
        new_osd = {}
        if optim_state_key_type == OptimStateKeyType.PARAM_NAME:  # ID -> name
            param_id_to_param = (
                _get_param_id_to_param_from_optim_input(model, optim_input)
                if using_optim_input
                else _get_param_key_to_param(optim)
            )
            param_to_param_name = _get_param_to_fqn(model)
            param_id_to_param_name: List[str] = [
                param_to_param_name[param] for param in param_id_to_param.values()
            ]
            new_osd["state"] = {
                param_id_to_param_name[param_id]: param_state
                for param_id, param_state in osd["state"].items()
            }
            new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
            for param_group in new_osd["param_groups"]:
                param_group["params"] = sorted(
                    [
                        param_id_to_param_name[param_id]
                        for param_id in param_group["params"]
                    ]
                )
            return new_osd
        elif optim_state_key_type == OptimStateKeyType.PARAM_ID:  # name -> ID
            param_name_to_param = _get_fqn_to_param(model)
            param_to_param_id = (
                _get_param_to_param_id_from_optim_input(model, optim_input)
                if using_optim_input
                else _get_param_to_param_key(optim)
            )
            # Because not all model parameters may be passed as the optimizer
            # input, we may need to drop some parameters from this mapping
            param_name_to_param_id = {
                param_name: param_to_param_id[param]
                for param_name, param in param_name_to_param.items()
                if param in param_to_param_id
            }
            new_osd["state"] = {
                param_name_to_param_id[param_name]: param_state
                for param_name, param_state in osd["state"].items()
            }
            new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
            for param_group in new_osd["param_groups"]:
                param_group["params"] = sorted(
                    [
                        param_name_to_param_id[param_name]
                        for param_name in param_group["params"]
                    ]
                )
            return new_osd
        return new_osd  # should never reach here

    @staticmethod
    def optim_state_dict(
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        optim_state_dict: Optional[Dict[str, Any]] = None,
        group: Optional[dist.ProcessGroup] = None,
    ) -> Dict[str, Any]:
        """
        Transform the state-dict of an optimizer corresponding to a sharded model.

        The given state-dict can be transformed to one of three types:
        1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict.

        For the full optimizer state_dict, all states are unflattened and not sharded.
        Rank0-only and CPU-only can be specified via :meth:`state_dict_type` to
        avoid OOM.

        For the sharded optimizer state_dict, all states are unflattened but sharded.
        CPU-only can be specified via :meth:`state_dict_type` to further save
        memory.

        For the local state_dict, no transformation will be performed. But a state
        will be converted from a torch.Tensor to a ShardedTensor to represent its
        sharding nature (this is not supported yet).

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
            >>> from torch.distributed.fsdp import StateDictType
            >>> from torch.distributed.fsdp import FullStateDictConfig
            >>> from torch.distributed.fsdp import FullOptimStateDictConfig
            >>> # Save a checkpoint
            >>> model, optim = ...
            >>> FSDP.set_state_dict_type(
            >>>     model,
            >>>     StateDictType.FULL_STATE_DICT,
            >>>     FullStateDictConfig(rank0_only=False),
            >>>     FullOptimStateDictConfig(rank0_only=False),
            >>> )
            >>> state_dict = model.state_dict()
            >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
            >>> save_a_checkpoint(state_dict, optim_state_dict)
            >>> # Load a checkpoint
            >>> model, optim = ...
            >>> state_dict, optim_state_dict = load_a_checkpoint()
            >>> FSDP.set_state_dict_type(
            >>>     model,
            >>>     StateDictType.FULL_STATE_DICT,
            >>>     FullStateDictConfig(rank0_only=False),
            >>>     FullOptimStateDictConfig(rank0_only=False),
            >>> )
            >>> model.load_state_dict(state_dict)
            >>> optim_state_dict = FSDP.optim_state_dict_to_load(
            >>>     model, optim, optim_state_dict
            >>> )
            >>> optim.load_state_dict(optim_state_dict)

        Args:
            model (torch.nn.Module): Root module (which may or may not be a
                :class:`FullyShardedDataParallel` instance) whose parameters
                were passed into the optimizer ``optim``.
            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
                parameters.
            optim_state_dict (Dict[str, Any]): The target optimizer state_dict to
                transform. If the value is ``None``, ``optim.state_dict()`` will
                be used. (Default: ``None``)
            group (dist.ProcessGroup): Model's process group across which parameters
                are sharded or ``None`` if using the default process group.
                (Default: ``None``)

        Returns:
            Dict[str, Any]: A :class:`dict` containing the optimizer state for
            ``model``. The sharding of the optimizer state is based on
            ``state_dict_type``.
        """
        state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
        if optim_state_dict is None:
            optim_state_dict = optim.state_dict()
        return FullyShardedDataParallel._optim_state_dict_impl(
            model=model,
            optim=optim,
            optim_state_dict=optim_state_dict,
            optim_input=None,
            rank0_only=getattr(
                state_dict_settings.optim_state_dict_config, "rank0_only", False
            ),
            full_state_dict=state_dict_settings.state_dict_type
            == StateDictType.FULL_STATE_DICT,
            group=group,
            cpu_offload=getattr(
                state_dict_settings.optim_state_dict_config, "offload_to_cpu", True
            ),
            _stacklevel=2,
        )

    @staticmethod
    def optim_state_dict_to_load(
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        optim_state_dict: Dict[str, Any],
        is_named_optimizer: bool = False,
        load_directly: bool = False,
        group: Optional[dist.ProcessGroup] = None,
    ) -> Dict[str, Any]:
        """
        Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model.

        Given an ``optim_state_dict`` that is transformed through
        :meth:`optim_state_dict`, it gets converted to the flattened optimizer
        state_dict that can be loaded into ``optim``, which is the optimizer for
        ``model``. ``model`` must be sharded by FullyShardedDataParallel.

            >>> # xdoctest: +SKIP("undefined variables")
            >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
            >>> from torch.distributed.fsdp import StateDictType
            >>> from torch.distributed.fsdp import FullStateDictConfig
            >>> from torch.distributed.fsdp import FullOptimStateDictConfig
            >>> # Save a checkpoint
            >>> model, optim = ...
            >>> FSDP.set_state_dict_type(
            >>>     model,
            >>>     StateDictType.FULL_STATE_DICT,
            >>>     FullStateDictConfig(rank0_only=False),
            >>>     FullOptimStateDictConfig(rank0_only=False),
            >>> )
            >>> state_dict = model.state_dict()
            >>> original_osd = optim.state_dict()
            >>> optim_state_dict = FSDP.optim_state_dict(
            >>>     model,
            >>>     optim,
            >>>     optim_state_dict=original_osd
            >>> )
            >>> save_a_checkpoint(state_dict, optim_state_dict)
            >>> # Load a checkpoint
            >>> model, optim = ...
            >>> state_dict, optim_state_dict = load_a_checkpoint()
            >>> FSDP.set_state_dict_type(
            >>>     model,
            >>>     StateDictType.FULL_STATE_DICT,
            >>>     FullStateDictConfig(rank0_only=False),
            >>>     FullOptimStateDictConfig(rank0_only=False),
            >>> )
            >>> model.load_state_dict(state_dict)
            >>> optim_state_dict = FSDP.optim_state_dict_to_load(
            >>>     model, optim, optim_state_dict
            >>> )
            >>> optim.load_state_dict(optim_state_dict)

        Args:
            model (torch.nn.Module): Root module (which may or may not be a
                :class:`FullyShardedDataParallel` instance) whose parameters
                were passed into the optimizer ``optim``.
            optim (torch.optim.Optimizer): Optimizer for ``model`` 's
                parameters.
            optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
            is_named_optimizer (bool): Whether this optimizer is a NamedOptimizer or
                KeyedOptimizer. Set to ``True`` only if ``optim`` is TorchRec's
                KeyedOptimizer or torch.distributed's NamedOptimizer.
            load_directly (bool): If this is set to True, this API will also
                call optim.load_state_dict(result) before returning the result.
                Otherwise, users are responsible for calling
                ``optim.load_state_dict()``. (Default: ``False``)
            group (dist.ProcessGroup): Model's process group across which parameters
                are sharded or ``None`` if using the default process group.
                (Default: ``None``)
        """
        state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
        result = FullyShardedDataParallel._optim_state_dict_to_load_impl(
            optim_state_dict=optim_state_dict,
            model=model,
            optim_input=None,
            optim=optim,
            full_state_dict=(
                state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT
            ),
            rank0_only=getattr(
                state_dict_settings.optim_state_dict_config, "rank0_only", False
            ),
            is_named_optimizer=is_named_optimizer,
            group=group,
        )
        if load_directly:
            optim.load_state_dict(result)
        return result

    def register_comm_hook(self, state: object, hook: callable):
        """Register a communication hook.

        This is an enhancement that provides users a flexible hook to specify how FSDP
        aggregates gradients across multiple workers.
        This hook can be used to implement several algorithms like
        `GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression,
        which involve different communication strategies for
        parameter syncs while training with :class:`FullyShardedDataParallel`.

        .. warning ::
            An FSDP communication hook should be registered only once, before
            running the initial forward pass.

        Args:
            state (object): Passed to the hook to maintain any state information during the training process.
                Examples include error feedback in gradient compression,
                peers to communicate with next in `GossipGrad <https://arxiv.org/abs/1803.05880>`_, etc.
                It is locally stored by each worker
                and shared by all the gradient tensors on the worker.
            hook (Callable): Callable, which has one of the following signatures:
                1) ``hook: Callable[torch.Tensor] -> None``:
                This function takes in a Python tensor, which represents
                the full, flattened, unsharded gradient with respect to all variables
                corresponding to the model this FSDP unit is wrapping
                (that are not wrapped by other FSDP sub-units).
                It then performs all necessary processing and returns ``None``;
                2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
                This function takes in two Python tensors. The first one represents
                the full, flattened, unsharded gradient with respect to all variables
                corresponding to the model this FSDP unit is wrapping
                (that are not wrapped by other FSDP sub-units). The latter
                represents a pre-sized tensor to store a chunk of the sharded gradient after
                reduction.
                In both cases, the callable performs all necessary processing and returns ``None``.
                Callables with signature 1 are expected to handle gradient communication for the `NO_SHARD` case.
                Callables with signature 2 are expected to handle gradient communication for sharded cases.
        """
        if not self.check_is_root():
            raise AssertionError(
                "register_comm_hook can only be called on a root instance."
            )
        for fsdp_state in traversal_utils._get_fsdp_states(self):
            if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
                raise AssertionError(
                    f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}"
                )
            if fsdp_state._comm_hook is not None:
                raise AssertionError("A communication hook is already registered")
            if not callable(hook):
                raise ValueError(
                    f"The communication hook must be callable but got {hook}"
                )
            fsdp_state._comm_hook = hook
            fsdp_state._comm_hook_state = state

    def _unshard(self, async_op: bool = False):
        class UnshardHandle:
            def __init__(
                self,
                flat_param_handle: Optional[FlatParamHandle],
                unshard_event: torch.Event,
            ):
                self._flat_param_handle = flat_param_handle
                self._unshard_event = unshard_event

            def wait(self):
                if self._flat_param_handle is not None:
                    current_stream = (
                        self._flat_param_handle._device_handle.current_stream()
                    )
                    current_stream.wait_event(self._unshard_event)
                    self._flat_param_handle = None

        if self._handle:
            with self._use_training_state(
                TrainingState.FORWARD_BACKWARD, HandleTrainingState.FORWARD
            ):
                _unshard(
                    self, self._handle, self._unshard_stream, self._pre_unshard_stream
                )
                self._unshard_event = self._unshard_stream.record_event()
            self._handle._prefetched = True
        unshard_handle = UnshardHandle(self._handle, self._unshard_event)
        if async_op:
            return unshard_handle
        unshard_handle.wait()
        return None

    def _wait_unshard_streams_on_current_stream(self):
        _wait_for_computation_stream(
            self._device_handle.current_stream(),
            self._unshard_stream,
            self._pre_unshard_stream,
        )

    @contextlib.contextmanager
    def _use_training_state(
        self, training_state: TrainingState, handle_training_state: HandleTrainingState
    ):
        prev_training_state = self.training_state
        self.training_state = training_state
        if self._handle:
            prev_handle_training_state = self._handle._training_state
            self._handle._training_state = handle_training_state
        try:
            yield
        finally:
            self.training_state = prev_training_state
            if self._handle:
                self._handle._training_state = prev_handle_training_state
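

# NOTE: Illustrative sketch, not part of this module's API: a minimal gradient
# communication hook matching signature 1 described in
# ``FullyShardedDataParallel.register_comm_hook`` (the ``NO_SHARD`` case),
# assuming -- as with the default hooks in
# ``torch.distributed.algorithms._comm_hooks.default_hooks`` -- that the
# registered ``state`` is passed as the first argument. The hook name and the
# use of ``state`` as an optional process group are assumptions for the
# example.
#
#     def _allreduce_mean_hook(state: object, grad: torch.Tensor) -> None:
#         group = state if isinstance(state, dist.ProcessGroup) else None
#         dist.all_reduce(grad, group=group)  # sum the unsharded gradient
#         grad.div_(dist.get_world_size(group))  # average across ranks
#
#     >>> # xdoctest: +SKIP("undefined variables")
#     >>> sharded_module.register_comm_hook(None, _allreduce_mean_hook)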


def _get_grad_norm(
    params: Iterable[nn.Parameter],
    norm_type: float,
    zero: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """
    Return the gradient norm of the parameters in ``params``, where the gradients are viewed as a single vector.

    The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream
    use of this return value is a reduction across ranks.
    """
    params_with_grad = [param for param in params if param.grad is not None]
    if len(params_with_grad) == 0:
        # Reuse a tensor for zero to avoid a GPU sync
        return zero
    grads = [param.grad for param in params_with_grad]
    grad_dtypes = {grad.dtype for grad in grads}
    if len(grad_dtypes) != 1:
        raise ValueError(
            f"Requires uniform dtype across all gradients but got {grad_dtypes}"
        )
    # Compute the gradient norm in FP32, where we treat the gradients as a
    # single vector
    grad_norm = torch.linalg.vector_norm(
        torch.stack(
            [
                torch.linalg.vector_norm(grad.detach(), norm_type, dtype=torch.float32)
                for grad in grads
            ],
        ),
        norm_type,
        dtype=torch.float32,
    )
    return grad_norm.to(device=device)


def _get_param_to_fqn(
    model: torch.nn.Module,
) -> Dict[torch.nn.Parameter, str]:
    """
    Construct a mapping from parameters to their parameter names.

    The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which
    means that none of the parameters should be ``FlatParameter`` s. As a
    result, compared to :meth:`_get_param_to_fqns`, the mapped
    values may be flattened from singleton :class:`list` s to the contained
    names themselves.

    Args:
        model (torch.nn.Module): Root module, which should not contain any
            :class:`FullyShardedDataParallel` instances.
    """
    param_to_param_names = _get_param_to_fqns(model)
    for param_names in param_to_param_names.values():
        assert (
            len(param_names) > 0
        ), "`_get_param_to_fqns()` should not construct empty lists"
        if len(param_names) > 1:
            raise RuntimeError(
                "Each parameter should only map to one parameter name but got "
                f"{len(param_names)}: {param_names}"
            )
    param_to_param_name = {
        param: param_names[0] for param, param_names in param_to_param_names.items()
    }
    return param_to_param_name


def _get_fqn_to_param(
    model: torch.nn.Module,
) -> Dict[str, torch.nn.Parameter]:
    """Construct the inverse mapping of :meth:`_get_param_to_fqn`."""
    param_to_param_name = _get_param_to_fqn(model)
    return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))
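

# NOTE: Illustrative sketch, not part of this module's API, of what the two
# helpers above compute for a plain module that is not wrapped with FSDP.
# ``lin`` is a placeholder name for the example.
#
#     >>> # xdoctest: +SKIP("illustrative only")
#     >>> lin = nn.Linear(2, 2, bias=False)
#     >>> _get_param_to_fqn(lin)   # {Parameter(...): "weight"}
#     >>> _get_fqn_to_param(lin)   # {"weight": Parameter(...)}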