# mypy: allow-untyped-defs
import functools
import logging
from enum import auto, Enum
from typing import Any, Callable, Dict, List, no_type_check, Optional, Set, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd.graph import register_multi_grad_hook
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
    _assert_in_training_states,
    _FSDPState,
    _get_module_fsdp_state,
    _is_composable,
    _log_post_backward_hook,
    _no_dispatch_record_stream,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._flat_param import (
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
    HandleTrainingState,
    RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp.api import BackwardPrefetch
from torch.distributed.utils import (
    _apply_to_tensors,
    _cast_forward_inputs,
    _p_assert,
    _to_kwargs,
)
from torch.utils import _pytree as pytree


logger = logging.getLogger(__name__)

# Do not include "process_group" to enable hybrid shard and MoE cases
HOMOGENEOUS_ATTR_NAMES = (
    "_use_orig_params",
    "limit_all_gathers",
    "_use_full_prec_in_eval",
)


class _PrefetchMode(Enum):
    BACKWARD = auto()
    FORWARD = auto()


def _get_fsdp_root_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """
    Returns a tuple containing:
    1. A list of the root ``_FSDPState`` instances in the module tree rooted at
       ``module`` without any duplicates and following the ``module.modules()``
       traversal order (which is assumed to be depth-first).
    2. A corresponding list of the root modules owning the states in the first
       list.

    This is similar to :func:`_get_fsdp_states_with_modules` except that we
    must call :func:`_is_fsdp_root` to force a lazy initialization to determine
    the FSDP root in case lazy initialization has not yet happened.
    """
    fsdp_root_states: List[_FSDPState] = []
    fsdp_root_modules: List[nn.Module] = []
    visited_fsdp_states: Set[_FSDPState] = set()
    # NOTE: This function assumes that `module.modules()` proceeds top-down.
    for submodule in module.modules():
        optional_state = _get_module_fsdp_state(submodule)
        if (
            optional_state is not None
            and optional_state not in visited_fsdp_states
            and _is_fsdp_root(optional_state, submodule)
        ):
            visited_fsdp_states.add(optional_state)
            fsdp_root_states.append(optional_state)
            fsdp_root_modules.append(submodule)
    return fsdp_root_states, fsdp_root_modules


def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
    """See :func:`_get_fsdp_root_states_with_modules`."""
    fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
    return fsdp_root_states


def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
    """
    Returns whether ``state`` corresponds to that of an FSDP root.

    For the wrapper code path, ``state`` and ``module`` should be the same. For
    the non-wrapper code path, ``state`` should be ``module`` 's state.
    """
    # Force a lazy initialization to determine the FSDP root
    _lazy_init(state, module)
    assert state._is_root is not None  # mypy
    return state._is_root


@no_type_check
def _lazy_init(
    state: _FSDPState,
    root_module: nn.Module,
) -> _FSDPState:
    """
    Performs initialization lazily, typically right before the first forward
    pass. The laziness is needed to ensure that the parameter device/dtype and
    the FSDP hierarchy have finalized. This method's actual logic only runs on
    the root FSDP instance, which performs initialization for all non-root FSDP
    instances to avoid partial initialization.

    For the non-composable code path, ``state`` and ``root_module`` should be
    the same, namely the FSDP instance itself.
    """
    if state._is_root is not None:
        return  # no-op: already lazily initialized
    if not state._device_handle.is_available():
        # Allow the FSDP constructor to run even without CUDA but check this
        # once we start real execution
        raise RuntimeError("FSDP does not support CPU only execution")
    # The following logic is only run on the root FSDP instance since it will
    # set `_is_root=False` for the non-root instances
    state._is_root = True
    _assert_in_training_states(state, [TrainingState.IDLE])
    _check_flat_params_on_expected_device(state, root_module)
    state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module)
    _init_streams(state)
    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
    state._exec_order_data.init(state, root_module, state.process_group)
    _share_state_and_init_handle_attrs(state, root_module)
    return state


def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
    """
    Checks that all ``FlatParameter``s in ``module`` 's tree managed by
    ``state`` are on the expected device for *lazy initialization*.
    """
    cpu_device = torch.device("cpu")
    for handle in traversal_utils._get_fsdp_handles(module):
        if (
            not handle._offload_params
            and handle.flat_param.device != state.compute_device
        ):
            raise RuntimeError(
                "An FSDP-managed module unexpectedly has parameters on "
                f"{handle.flat_param.device}. Make sure to move the module to "
                f"{state.compute_device} before training."
            )
        elif handle._offload_params and handle.flat_param.device != cpu_device:
            raise RuntimeError(
                "An FSDP-managed module with parameter CPU offloading enabled "
                f"has parameters on {handle.flat_param.device}. Make sure to "
                f"not move the module from CPU when offloading parameters."
            )


@no_type_check
def _share_state_and_init_handle_attrs(
    root_state: _FSDPState,
    root_module: nn.Module,
) -> None:
    """
    Shares data structure state from the ``root_state`` to all FSDP states in
    ``root_module`` 's module tree, and initializes handle attributes. These
    are done together to require a single loop over the states.
    """
    handle = root_state._handle
    if handle:
        handle.init_flat_param_attributes()
    attr_name_to_values: Dict[str, Set[Any]] = {}
    for attr_name in HOMOGENEOUS_ATTR_NAMES:
        attr_name_to_values[attr_name] = set()
    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
    # Update _has_optim_in_backward for each handle.
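    # A handle "has optim in backward" when any of its original parameters has
    # in-backward optimizers attached (e.g. via
    # `torch.distributed.optim._apply_optimizer_in_backward`), which FSDP only
    # supports together with `use_orig_params=True`.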
    for handle in root_state._all_handles:
        flat_param = handle.flat_param
        if hasattr(flat_param, "_in_backward_optimizers"):
            raise RuntimeError(
                "FSDP optimizer in backward only supported with use_orig_params=True!"
            )
        handle._has_optim_in_backward = flat_param._params is not None and any(
            hasattr(param, "_in_backward_optimizers") for param in flat_param._params
        )
        if handle._has_optim_in_backward:
            torch._C._log_api_usage_once("fsdp.optimizer_in_backward")
    for fsdp_state in root_state._all_fsdp_states:
        for attr_name in HOMOGENEOUS_ATTR_NAMES:
            _p_assert(
                hasattr(fsdp_state, attr_name),
                f"FSDP state missing attribute {attr_name}",
            )
            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
        if fsdp_state is root_state:
            continue
        # Relax the assert for non-root FSDP instances in case the nested
        # initialized module is wrapped again in FSDP later (e.g. after
        # training to run inference)
        _p_assert(
            fsdp_state._is_root is None or not fsdp_state._is_root,
            "Non-root FSDP instance's `_is_root` should not have been "
            "set yet or should have been set to `False`",
        )
        fsdp_state._is_root = False
        fsdp_state._unshard_stream = root_state._unshard_stream
        fsdp_state._post_backward_stream = root_state._post_backward_stream
        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
        fsdp_state._all_reduce_stream = root_state._all_reduce_stream
        fsdp_state._default_stream = root_state._default_stream
        fsdp_state._exec_order_data = root_state._exec_order_data
        fsdp_state._free_event_queue = root_state._free_event_queue
        if fsdp_state._fsdp_extension is not None:
            fsdp_state._fsdp_extension.compute_stream = root_state._default_stream
        handle = fsdp_state._handle
        if handle:
            handle.init_flat_param_attributes()
    for attr_name, attr_values in attr_name_to_values.items():
        if len(attr_values) != 1:
            raise ValueError(
                f"Expects one homogeneous value for {attr_name} but got {attr_values}"
            )


@no_type_check
def _init_streams(
    state: _FSDPState,
) -> None:
    """
    Initializes CUDA streams for overlapping communication, computation, and
    data transfers. The streams should be shared across FSDP instances.
    """
    assert state._is_root
    assert state._device_handle.is_available()
    uses_hybrid_sharding = any(
        fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
        for fsdp_state in state._all_fsdp_states
    )
    # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and
    # preserve the default priority of 0 otherwise
    high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0
    # Default stream for computation
    state._default_stream = state._device_handle.current_stream()
    if state._fsdp_extension is not None:
        # set the compute stream to the FSDP extension
        state._fsdp_extension.compute_stream = state._default_stream

    # Stream for unshard logic, including allocating the all-gather destination
    # tensors and the all-gathers themselves
    state._unshard_stream = state._device_handle.Stream(priority=high_priority)
    # Stream for overlapping gradient reduction with the backward pass gradient
    # computation
    state._post_backward_stream = state._device_handle.Stream(priority=high_priority)
    # Stream for pre-unshard logic, namely allocations and writes for CPU
    # offloading (H2D copy) and mixed precision (low precision cast)
    state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority)
    # Stream to run HSDP's all-reduce as async (if using HSDP)
    state._all_reduce_stream = (
        state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream
    )


@no_type_check
def _unshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
) -> None:
    """
    Unshards ``handle`` 's ``FlatParameter``. If the handle is in
    :meth:`summon_full_params` and is using mixed precision, then it is
    forced to full precision.

    Postcondition: handle's ``FlatParameter`` 's data is the padded
    unsharded flat parameter on the compute device.
    """
    if not handle:
        return
    with state._device_handle.stream(pre_unshard_stream):
        ran_pre_unshard = handle.pre_unshard()
    if ran_pre_unshard:
        unshard_stream.wait_stream(pre_unshard_stream)
    if state.limit_all_gathers:
        event = state._free_event_queue.dequeue_if_needed()
        if event:
            with torch.profiler.record_function(
                "FullyShardedDataParallel.rate_limiter"
            ):
                event.synchronize()
    with state._device_handle.stream(unshard_stream):
        handle.unshard()
        handle.post_unshard()


@no_type_check
def _reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    free_unsharded_flat_param: bool,
):
    """
    Reshards the handle. ``free_unsharded_flat_param`` indicates whether to
    free the handle's padded unsharded flat parameter.
    """
    handle.reshard(free_unsharded_flat_param)
    if state.limit_all_gathers and free_unsharded_flat_param:
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            # We don't run an event queue for freeing under torch compile atm
            # But maybe we need to? TODO(voz): Look into this
            free_event = state._device_handle.Event()
            free_event.record()
            state._free_event_queue.enqueue(free_event)
    handle.post_reshard()
    # Flat parameter freed or not, we always have to "unshard" the parameter
    # upon next access to get its shape correct.
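    # Clear the prefetch flag so that the next pre-forward/pre-backward for
    # this handle does not skip its `_unshard()` call.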
    handle._prefetched = False


def _unshard_grads(
    handle: Optional[FlatParamHandle],
) -> None:
    if handle:
        handle.unshard_grad()


def _reshard_grads(
    handle: Optional[FlatParamHandle],
) -> None:
    if handle:
        handle.reshard_grad()


@no_type_check
def _pre_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    unshard_fn: Callable,
    module: nn.Module,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Runs the pre-forward logic. This includes an opportunity to unshard
    currently sharded parameters such as those for the current forward and
    registering post-backward hooks for these current parameters. This function
    also converts forward ``args`` and ``kwargs`` to the given precision.

    Args:
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward, if any.
        unshard_fn (Optional[Callable]): A callable to unshard any currently
            sharded parameters or ``None`` to not do any unsharding.
        module (nn.Module): Module whose forward this method runs right before;
            expected by the hook signature.
        args (Tuple[Any, ...]): Module forward ``args``.
        kwargs (Dict[str, Any]): Module forward ``kwargs``.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"):
        # For `fully_shard` + `checkpoint`, skip pre-forward logic in the
        # recomputed forward
        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
            # For both checkpoint implementations, we do not need to re-cast
            # inputs here since they will be checkpointed in the low precision
            # either by AC or normally by autograd as long as the AC region is
            # nested within FSDP
            return args, kwargs
        state.training_state = TrainingState.FORWARD_BACKWARD
        state._exec_order_data.record_pre_forward(handle, module.training)
        if handle:
            handle._training_state = HandleTrainingState.FORWARD
        if unshard_fn is not None:
            unshard_fn(state, handle)
        # Register post-backward hooks to reshard the parameters and reduce-scatter
        # their gradients. They must be re-registered every forward pass in case
        # the `grad_fn` is mutated.
        _register_post_backward_hook(state, handle)
        # We have to reallocate the _cpu_grad if optimizer overlap
        # set the grad to None in the backward pass.
        if handle and handle._offload_params and handle.flat_param._cpu_grad is None:
            handle.flat_param._cpu_grad = torch.zeros_like(
                handle.flat_param._local_shard, device=torch.device("cpu")
            ).pin_memory(device=state.compute_device)

        should_cast_forward_inputs = (
            state._handle and not state._handle._force_full_precision
        )

        if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs:
            # Recursively convert args and kwargs to specified precision.
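            # NOTE: `_cast_forward_inputs` only casts floating-point tensors
            # nested in `args`/`kwargs`; other objects pass through unchanged.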
            input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
            args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
        _register_post_backward_reshard_only_hook(state, handle, args, kwargs)
        return args, kwargs


@no_type_check
def _pre_forward_unshard(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
) -> None:
    """Unshards parameters in the pre-forward."""
    if not handle:
        return
    # If the handles have been prefetched, then there is no need to call
    # `_unshard()` again
    if not handle._prefetched:
        _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
    handle._needs_pre_forward_unshard = False
    # Don't wait during trace
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
        current_stream = state._device_handle.current_stream()
        if state._unshard_event is not None:
            current_stream.wait_event(state._unshard_event)
            state._unshard_event = None
        else:
            current_stream.wait_stream(state._unshard_stream)
    with torch.profiler.record_function(
        "FullyShardedDataParallel._pre_forward_prefetch"
    ):
        _prefetch_handle(state, handle, _PrefetchMode.FORWARD)


@no_type_check
def _post_forward(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    reshard_fn: Callable,
    module: nn.Module,
    input: Any,
    output: Any,
) -> Any:
    """
    Runs the post-forward logic. This includes an opportunity to reshard
    currently unsharded parameters such as those used in the current forward
    and registering pre-backward hooks on the forward outputs.

    Args:
        handle (Optional[FlatParamHandle]): Handle giving the parameters used
            in the current forward, if any.
        reshard_fn (Optional[Callable]): A callable to reshard any currently
            unsharded parameters (e.g. from the current forward) or ``None`` to
            not do any resharding.
        module (nn.Module): Module whose forward just ran, which should be a
            fully sharded module (see [Note: Fully Sharded Module]); expected
            by the hook signature.
        input (Any): Unused; expected by the hook signature.
        output (Any): Forward pass output; pre-backward hooks are registered on
            the tensors that require gradients in this output.

    Postcondition: Each ``FlatParameter`` 's data points to the sharded flat
    parameter.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._post_forward"):
        # For `fully_shard` + `checkpoint`, skip post-forward logic in the
        # recomputed forward
        if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE:
            return output

        state._exec_order_data.record_post_forward(handle)
        if reshard_fn is not None:
            reshard_fn(state, handle)
        # Register pre-backward hooks to unshard the flat parameters for the
        # gradient computation (if needed)
        output = _register_pre_backward_hooks(state, module, output, handle)
        state.training_state = TrainingState.IDLE
        if handle:
            handle._training_state = HandleTrainingState.IDLE
        return output


@no_type_check
def _post_forward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
) -> None:
    """Reshards parameters in the post-forward."""
    if not handle:
        return
    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
    # with the intention that they are immediately used for backward
    # computation (though this may not be true)
    free_unsharded_flat_param = (
        not state._is_root
        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
    )
    _reshard(state, handle, free_unsharded_flat_param)


@no_type_check
def _root_pre_forward(
    state: _FSDPState,
    module: nn.Module,
    args,
    kwargs,
) -> None:
    """
    Runs pre-forward logic specific to the root FSDP instance, which should run
    before any individual module's pre-forward. This starts with an attempt at
    lazy initialization (which only runs non-vacuously once). Otherwise, if
    this is called on a non-root FSDP instance, then it returns directly.

    Args:
        module (nn.Module): Module for which this logic tries to run. It may or
            may not be the root. If not, then this method does not do anything.
    """
    with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"):
        _lazy_init(state, module)
        _p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
        if not state._is_root:
            # Always cast forward inputs in the root of this local FSDP unit for mixed
            # precision, as this is where mixed precision could be configured.
            # This is more useful for auto wrapping, which is recommended for the
            # composable path. For manual wrapping, casting forward inputs on each
            # local FSDP unit root would add some overhead, so it is not turned on
            # for the model wrapper path right now, where manual wrapping is more
            # broadly used.
            if _is_composable(state):
                return _root_cast_forward_input(state, module, args, kwargs)
            return args, kwargs

        # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
        # are in full precision and if we should cast them back to lower precision, which happens when
        # exiting eval() mode.
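        # (Forcing full precision typically corresponds to full-precision eval
        # via `use_full_prec_in_eval` or to the `SUMMON_FULL_PARAMS` training
        # state when mixed precision is enabled.)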
        handle = state._handle
        if handle:
            should_cast_buffers_to_full_prec = handle._force_full_precision
        else:
            should_cast_buffers_to_full_prec = True

        if should_cast_buffers_to_full_prec:
            _cast_buffers_to_dtype_and_device(
                buffers=dict(module.named_buffers()).values(),
                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
                device=state.compute_device,
            )
            # This flag is only set when we cast buffers to full precision, to avoid the
            # CPU overhead that can stem from retrieving all buffers and their types in the
            # following else branch.
            state._needs_buffer_dtype_restore_check = True
        elif getattr(state, "_needs_buffer_dtype_restore_check", False):
            # Check if buffers are in full precision and we need to cast them
            # back down.
            (
                buffers,
                buffer_dtypes_for_computation,
            ) = _get_buffers_and_dtypes_for_computation(state, module)
            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
                if any(
                    buffer.dtype != buffer_dtype_for_computation
                    for buffer, buffer_dtype_for_computation in zip(
                        buffers, buffer_dtypes_for_computation
                    )
                ):
                    # Assume we have to cast everything if there is one mismatch
                    _cast_buffers_to_dtype_and_device(
                        buffers, buffer_dtypes_for_computation, state.compute_device
                    )
            # We don't have to check this again until we cast buffers to full precision again.
            state._needs_buffer_dtype_restore_check = False

        if state.forward_prefetch:
            handles = []
            for fsdp_state in state._all_fsdp_states:
                if fsdp_state._handle:
                    handles.append(fsdp_state._handle)
            for handle in handles:
                handle._needs_pre_forward_unshard = True
                handle._prefetched = False
        _wait_for_computation_stream(
            state._device_handle.current_stream(),
            state._unshard_stream,
            state._pre_unshard_stream,
        )
        _reset_flat_param_grad_info_if_needed(state._all_handles)

        # Prepares the forward inputs by moving them to ``compute_device``
        # TODO: Do not use the side stream for tensor copies for now; investigate
        # the perf with/without it.
        with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"):
            args_tuple, kwargs_tuple = _to_kwargs(
                args, kwargs, state.compute_device, False
            )
        args = args_tuple[0]
        kwargs = kwargs_tuple[0]

        return _root_cast_forward_input(state, module, args, kwargs)


@no_type_check
def _root_cast_forward_input(
    state: _FSDPState, module: torch.nn.Module, args, kwargs
) -> Tuple[Any, Any]:
    if state._handle:
        force_full_precision = not state._handle._force_full_precision
    else:
        force_full_precision = True

    should_cast_forward_inputs = (
        (module.training or not state._use_full_prec_in_eval) and force_full_precision
    ) and state.mixed_precision.cast_root_forward_inputs

    if should_cast_forward_inputs:
        input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
        args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)

    return args, kwargs


@no_type_check
def _pre_backward_hook(
    state: _FSDPState,
    module: nn.Module,
    handle: FlatParamHandle,
    grad,
    *unused: Any,
) -> Any:
    """
    Prepares ``handle`` 's ``FlatParameter`` for gradient computation.

    Args:
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).
    """
    # Only run the pre-backward hook once per group of handles involved in the
    # same module forward computation
    if (
        handle
        and hasattr(handle, "_ran_pre_backward_hook")
        and handle._ran_pre_backward_hook
    ):
        return grad

    with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"):
        # Queue the post-backward callback once for the root FSDP instance to
        # attach it to the outermost backward graph task so that it is called
        # after all backward calls complete
        if state._is_root and not state._post_backward_callback_queued:
            _register_post_backward_final_callback(state, module)
            _reset_flat_param_grad_info_if_needed(state._all_handles)
        elif handle:
            allowed_states = [TrainingState.IDLE]
            if _is_composable(state):
                allowed_states.append(TrainingState.FORWARD_BACKWARD)
            _assert_in_training_states(state, allowed_states)
        state.training_state = TrainingState.FORWARD_BACKWARD
        # Queueing the post-backward callback is the only logic that is not
        # per-handle in the pre-backward hook, so we can return early here if
        # there are no handles.
        if not handle:
            return grad
        handle._training_state = HandleTrainingState.BACKWARD_PRE

        if handle._needs_pre_backward_unshard:
            # If the handles have been prefetched, then there is no need to
            # call `_unshard()` again
            if not handle._prefetched:
                _unshard(
                    state,
                    handle,
                    state._unshard_stream,
                    state._pre_unshard_stream,
                )
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._device_handle.current_stream().wait_stream(state._unshard_stream)

        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard these handles
        handle._needs_pre_backward_unshard = False
        with torch.profiler.record_function(
            "FullyShardedDataParallel._pre_backward_prefetch"
        ):
            _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
        handle.prepare_gradient_for_backward()
        handle._ran_pre_backward_hook = True
        return grad


@no_type_check
@torch.no_grad()
def _post_backward_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    flat_param,
    *unused: Any,
):
    """
    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.

    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
    unsharded gradient for the local batch.

    Postcondition:
    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
      unsharded gradient.
    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
      gradient (accumulating with any existing gradient).
    """
    _log_post_backward_hook(state, handle, logger)
    flat_param = handle.flat_param
    flat_param._post_backward_called = True
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook"
    ):
        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
        # For multiple applications of reentrant AC across submodules sharing
        # the same `FlatParameter`, the post-backward hook may run multiple
        # times in one backward, in which case we permit the state to already
        # be in `BACKWARD_POST`.
        _p_assert(
            handle._training_state
            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
        )
        handle._training_state = HandleTrainingState.BACKWARD_POST

        if flat_param.grad is None:
            return
        if flat_param.grad.requires_grad:
            raise RuntimeError("FSDP does not support gradients of gradients")

        _post_backward_reshard(state, handle)
        if not state._sync_gradients:
            if handle._use_orig_params:
                handle._use_unsharded_grad_views()
            return

        # Wait for all ops in the current stream (e.g. gradient computation) to
        # finish before reduce-scattering the gradient
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            state._post_backward_stream.wait_stream(
                state._device_handle.current_stream()
            )

        with state._device_handle.stream(state._post_backward_stream):
            autograd_computed_grad = flat_param.grad.data
            if (
                not _low_precision_hook_enabled(state)
                and flat_param.grad.dtype != handle._reduce_dtype
                # If we are forcing full precision but communicating grads
                # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient.
                and not handle._force_full_precision
            ):
                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
            if handle.uses_sharded_strategy:
                _reduce_grad(state, handle)
            else:
                _reduce_grad_no_shard(state, handle)
            # Since the unsharded gradient is produced in the computation
            # stream and consumed in the post-backward stream, inform the
            # caching allocator (before it goes out of scope)
            _no_dispatch_record_stream(
                autograd_computed_grad, state._post_backward_stream
            )


def _post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    with torch.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook_reshard_only"
    ):
        # `_pre_backward_hook` may not get executed if the forward output does
        # not require grad, so overwrite the IDLE state for post-backward
        # prefetching
        state.training_state = TrainingState.FORWARD_BACKWARD
        handle._training_state = HandleTrainingState.BACKWARD_POST
        _post_backward_reshard(state, handle)


def _post_backward_reshard(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
) -> None:
    free_unsharded_flat_param = _should_free_in_backward(state, handle)
    _reshard(state, handle, free_unsharded_flat_param)

    # TODO: Post-backward prefetching does not support the multiple handles
    # per module case since the post-backward hook runs per handle, not per
    # group of handles.
    with torch.profiler.record_function(
        "FullyShardedDataParallel._post_backward_prefetch"
    ):
        _prefetch_handle(state, handle, _PrefetchMode.BACKWARD)


@no_type_check
def _should_free_in_backward(
    state: _FSDPState,
    handle: FlatParamHandle,
) -> bool:
    """
    Returns whether FSDP should free the unsharded flat parameter in the
    post-backward or not.
    """
    if not handle.uses_sharded_strategy:
        return False
    # If not syncing gradients, then we do not free for strategies that do not
    # reshard after forward as a *heuristic* to trade off higher memory for
    # higher throughput.
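    # For example, under `no_sync()` with `SHARD_GRAD_OP`, the unsharded flat
    # parameter stays allocated so that the next forward can skip its
    # all-gather, whereas `FULL_SHARD` frees it regardless.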
    return (
        state._sync_gradients
        or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
    )


@no_type_check
def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    For sharded strategies, this runs gradient reduction, sharded gradient
    accumulation if needed, and the post-reduction callback.
    """
    flat_param = handle.flat_param
    uses_hybrid_sharded_strategy = handle._sharding_strategy in (
        HandleShardingStrategy.HYBRID_SHARD,
        HandleShardingStrategy._HYBRID_SHARD_ZERO2,
    )
    # We clear `.grad` to permit multiple backwards. This avoids a race where
    # the second backward pass's computation runs ahead of the first backward
    # pass's reduction, which is possible since the reduction is issued in a
    # separate stream and is async, and it would result in reducing the wrong
    # gradient.
    unsharded_grad = flat_param.grad.data
    flat_param.grad = None
    padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors(
        state, unsharded_grad
    )
    if state._comm_hook is None:  # default path
        _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor)
        pg = (
            handle._fake_process_group
            if handle._use_fake_reduce
            else state.process_group
        )
        dist.reduce_scatter_tensor(
            new_sharded_grad,
            padded_unsharded_grad,
            group=pg,
        )
        if uses_hybrid_sharded_strategy:
            # Don't wait during trace
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                state._all_reduce_stream.wait_stream(state._post_backward_stream)
            with state._device_handle.stream(state._all_reduce_stream):
                # Since the new sharded gradient is produced in the post-
                # backward stream and consumed in the all-reduce stream,
                # inform the caching allocator
                _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream)
                dist.all_reduce(new_sharded_grad, group=state._inter_node_pg)
                _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
                grad_to_offload = _accumulate_sharded_grad(
                    state, handle, new_sharded_grad
                )
                _post_reduce_grad_callback(state, handle, grad_to_offload)
                return
        _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor)
    else:
        state._comm_hook(
            state._comm_hook_state, padded_unsharded_grad, new_sharded_grad
        )
        # NOTE: HSDP variants do not support communication hook.
    grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad)
    _post_reduce_grad_callback(state, handle, grad_to_offload)


@no_type_check
def _get_reduce_scatter_tensors(
    state: _FSDPState, unsharded_grad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the input and output tensors to reduce-scatter, respectively.
    """
    chunks = list(unsharded_grad.chunk(state.world_size))
    numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel()
    padded_unsharded_grad = (
        F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad
    )
    new_sharded_grad = torch.empty_like(chunks[0])  # padded
    return padded_unsharded_grad, new_sharded_grad


@no_type_check
def _accumulate_sharded_grad(
    state: _FSDPState,
    handle: FlatParamHandle,
    sharded_grad: torch.Tensor,
) -> torch.Tensor:
    """
    Accumulates the reduce-scattered sharded gradient with any existing sharded
    gradient if needed, returning the gradient to offload (if CPU offloading is
    enabled).
    """
    flat_param = handle.flat_param
    _cast_grad_to_param_dtype(state, sharded_grad, flat_param)
    # Save the sharded gradient in `_saved_grad_shard` to support gradient
    # accumulation -- for multiple backwards, the gradient reductions may
    # happen in arbitrary order
    accumulate_grad = hasattr(flat_param, "_saved_grad_shard")
    if accumulate_grad:
        _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard)
        flat_param._saved_grad_shard += sharded_grad
    else:
        flat_param._saved_grad_shard = sharded_grad
    grad_to_offload = flat_param._saved_grad_shard
    return grad_to_offload


@no_type_check
def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None:
    """
    For no-shard, this runs gradient reduction (which directly covers any
    gradient accumulation implicitly) and the post-reduction callback.
    """
    flat_param = handle.flat_param
    if state._comm_hook is None:  # default path
        _div_if_needed(flat_param.grad, state._gradient_predivide_factor)
        dist.all_reduce(flat_param.grad, group=state.process_group)
        _div_if_needed(flat_param.grad, state._gradient_postdivide_factor)
    else:
        state._comm_hook(state._comm_hook_state, flat_param.grad)
    # For `NO_SHARD`, we can keep the low precision gradients by simply
    # omitting the cast altogether
    if not handle._keep_low_precision_grads:
        _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
    grad_to_offload = flat_param.grad.data
    _post_reduce_grad_callback(state, handle, grad_to_offload)


@no_type_check
def _post_reduce_grad_callback(
    state: _FSDPState,
    handle: FlatParamHandle,
    # Additional arguments needed for the callback logic
    grad_to_offload: torch.Tensor,
):
    """
    This callback captures any logic to run after the gradient reduction
    finishes. Currently, this offloads the gradient to CPU if CPU offloading is
    enabled and uses sharded gradient views if ``use_orig_params=True``.
    """
    _offload_grad(state, handle, grad_to_offload)
    _post_backward_use_sharded_grad_views(handle)


@no_type_check
def _offload_grad(
    state: _FSDPState,
    handle: FlatParamHandle,
    grad_to_offload: torch.Tensor,
):
    if not handle._offload_params:
        return
    # Offload the gradient to CPU to ensure parameters and gradients are on the
    # same device as required by the optimizer
    # TODO: Investigate why `NO_SHARD` breaks correctness when using
    # `non_blocking=True` here.
    # TODO (rohan-varma): When CPU offload and optimizer overlap,
    # non_blocking=True won't work since the copy may have not finished before
    # the optimizer step executes on CPU. If we want to use non-blocking=True
    # here, we'll have to synchronize before using result on CPU.
    non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward
    handle.flat_param._cpu_grad.copy_(
        grad_to_offload.detach(), non_blocking=non_blocking
    )  # synchronized in the post-backward callback
    # Since the gradient being offloaded may have been produced in the
    # computation stream and is being consumed here in the post-backward
    # stream, inform the caching allocator
    _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream)


@no_type_check
def _post_backward_use_sharded_grad_views(handle: FlatParamHandle):
    if not handle._use_orig_params:
        return
    # Since the handle's `FlatParameter` completed its gradient computation, we
    # should reset the gradient noneness mask
    handle._reset_is_grad_none()
    # Delay using sharded gradient views until after the reduce-scatter instead
    # of immediately after resharding
    handle._use_sharded_grad_views()
    if handle._has_optim_in_backward:
        handle.prepare_gradient_for_optim()
        for orig_param in handle.flat_param._params:
            # Check for `None` gradient to filter parameters not in the rank
            if orig_param.grad is not None and hasattr(
                orig_param, "_in_backward_optimizers"
            ):
                # TODO (rohan-varma): For CPU offload, this unfortunately
                # operates on CPU because the parameters and gradients have
                # already been offloaded. We should run this on GPU after
                # refactoring.
                for optim in orig_param._in_backward_optimizers:
                    optim.step()

                optim.zero_grad(set_to_none=True)
        handle._reset_flat_param_grad_info_if_needed()
        if handle._offload_params:
            handle.flat_param._cpu_grad = None


def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None:
    if div_factor > 1:
        tensor.div_(div_factor)


@no_type_check
def _cast_grad_to_param_dtype(
    state: _FSDPState,
    sharded_grad: torch.Tensor,
    param: FlatParameter,
):
    """
    Casts ``sharded_grad`` back to the full parameter dtype so that the
    optimizer step runs with that dtype. This performs an actual cast if
    1. parameters were in reduced precision during the forward since then
       gradients would be in that reduced precision, or
    2. parameters were not in reduced precision but gradients were in
       reduced precision for communication.
    However, if a low precision communication hook is registered, then this
    dtype cast happens in the hook instead.
    """
    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
    if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype:
        low_prec_grad_data = sharded_grad.data
        sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
        # Since for `NO_SHARD`, the gradient is produced in the computation
        # stream and consumed here in the post-backward stream, inform the
        # caching allocator; for the sharded strategies, the gradient is
        # produced in the post-backward stream, so this `record_stream()`
        # should be a no-op
        _no_dispatch_record_stream(
            low_prec_grad_data, state._device_handle.current_stream()
        )


def _check_grad_to_accumulate(
    new_sharded_grad: torch.Tensor,
    accumulated_grad: torch.Tensor,
) -> None:
    _p_assert(
        accumulated_grad.shape == new_sharded_grad.shape,
        "Shape mismatch when accumulating gradients: "
        f"existing gradient shape={accumulated_grad.shape} "
        f"new gradient shape={new_sharded_grad.shape}",
    )
    _p_assert(
        accumulated_grad.device == new_sharded_grad.device,
        "Device mismatch when accumulating gradients: "
        f"existing gradient device={accumulated_grad.device} "
        f"new gradient device={new_sharded_grad.device}",
    )


@no_type_check
def _low_precision_hook_enabled(state: _FSDPState) -> bool:
    return state._comm_hook in LOW_PRECISION_HOOKS


@no_type_check
@torch.no_grad()
def _post_backward_final_callback(
    state: _FSDPState,
    module: nn.Module,
):
    """
    This waits for the post-backward to finish and performs some final cleanup.
    This runs at the end of the entire backward pass and should only be called
    on the root FSDP instance.
    """
    _p_assert(
        state._is_root,
        "The post-backward callback should only be called on the root FSDP instance",
    )
    root_state = state

    if root_state._sync_gradients:
        current_stream = state._device_handle.current_stream()
        # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish
        # since it currently runs in the post-backward stream. That can be
        # pushed to the next forward if run in a different stream
        current_stream.wait_stream(root_state._post_backward_stream)
        if root_state._all_reduce_stream is not current_stream:  # uses HSDP
            current_stream.wait_stream(root_state._all_reduce_stream)
        if root_state.cpu_offload.offload_params:
            # Wait for non-blocking GPU -> CPU sharded gradient copies from the
            # post-backward hooks to finish explicitly since CPU gradients do
            # not automatically synchronize with the GPU
            state._device_handle.current_stream().synchronize()
    root_state._exec_order_data.next_iter()

    for fsdp_state in state._all_fsdp_states:
        _catch_all_reshard(fsdp_state)
        _finalize_params(fsdp_state)
        fsdp_state.training_state = TrainingState.IDLE
        handle = fsdp_state._handle
        if handle:
            handle._ran_pre_backward_hook = False
            handle._needs_pre_backward_unshard = False
            handle._post_forward_index = None
            handle._training_state = HandleTrainingState.IDLE
            handle._prefetched = False
    # Reset for cases like one forward and multiple backwards
    root_state._post_backward_callback_queued = False


@no_type_check
def _catch_all_reshard(
    state: _FSDPState,
) -> None:
    """
    Reshards the parameters that may not have been resharded in the
    post-backward hook. This can happen when a module's output is used in the
    forward pass, meaning that its pre-backward hook runs (unsharding the
    parameter), but the post-backward hook does not run because the output was
    not used in the loss computation corresponding to this backward pass.
    """
    # Wrap with a try-except to provide a more informative traceback if an
    # error is raised
    try:
        if state._handle:
            # TODO: This already-resharded check is brittle:
            # https://github.com/pytorch/pytorch/issues/83956
            already_resharded = (
                state._handle.flat_param.data_ptr()
                == state._handle.flat_param._local_shard.data_ptr()
                # If FSDP skipped using sharded views, then the flat parameter
                # still points to the sharded data, so we need to reshard to
                # use sharded views
                and not state._handle._skipped_use_sharded_views
            )
            if already_resharded:
                return
            free_unsharded_flat_param = _should_free_in_backward(state, state._handle)
            _reshard(state, state._handle, free_unsharded_flat_param)
    except Exception as e:
        _p_assert(
            False,
            f"Got exception in the catch-all reshard for {state}: {str(e)}",
            raise_assertion_error=False,
        )
        raise e


@no_type_check
def _finalize_params(
    state: _FSDPState,
) -> None:
    """Finalizes the parameters before the next iteration."""
    handle = state._handle
    if not handle:
        return
    flat_param = handle.flat_param
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        if hasattr(flat_param, "_post_backward_hook_handle"):
            pbhs_handle = flat_param._post_backward_hook_handle
            pbhs_handle.remove()
            del flat_param._post_backward_hook_handle
    else:
        if hasattr(flat_param, "_post_backward_hook_state"):
            post_backward_hook_state_len = len(flat_param._post_backward_hook_state)
            expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1
            _p_assert(
                post_backward_hook_state_len == expected_post_backward_hook_state_len,
                f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
            )
            flat_param._post_backward_hook_state[-1].remove()
            delattr(flat_param, "_post_backward_hook_state")
    if flat_param.requires_grad:
        if not state._sync_gradients:
            # Preserve the gradient accumulation state if not synchronizing
            # gradients: `.grad` remains the unsharded gradient from prior
            # `no_sync()` iterations, and `_saved_grad_shard` remains the
            # sharded gradient from the last synchronized iteration
            return
        if not handle._has_optim_in_backward:
            handle.prepare_gradient_for_optim()
        _p_assert(
            hasattr(flat_param, "_post_backward_called"),
            "Expects `_post_backward_called` to be set on the `FlatParameter`",
        )
        flat_param._post_backward_called = False


@no_type_check
def _prefetch_handle(
    state: _FSDPState,
    current_handle: Optional[FlatParamHandle],
    prefetch_mode: _PrefetchMode,
) -> None:
    """
    Prefetches the next handle if needed (without synchronization). If
    ``current_handle`` is ``None``, then this is a no-op.
    """
    if not current_handle:
        return
    handle = _get_handle_to_prefetch(state, current_handle)
    if not handle:
        return
    # Temporarily emulate the training state while calling `_unshard` to
    # ensure the correct `as_params` for `_use_unsharded_views()`
    prev_training_state = handle._training_state
    if prefetch_mode == _PrefetchMode.BACKWARD:
        handle._training_state = HandleTrainingState.BACKWARD_PRE
    elif prefetch_mode == _PrefetchMode.FORWARD:
        handle._training_state = HandleTrainingState.FORWARD
    else:
        raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}")
    # Prefetch the next set of handles without synchronizing to allow
    # the sync to happen as late as possible to maximize overlap
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
    handle._training_state = prev_training_state
    handle._prefetched = True


@no_type_check
def _get_handle_to_prefetch(
    state: _FSDPState,
    current_handle: FlatParamHandle,
) -> FlatParamHandle:
    """
    Returns the handle to prefetch for the next module(s), where
    ``current_handle`` represents the current module, or ``None`` if there is
    nothing to prefetch.

    "Prefetching" refers to running the unshard logic early (without
    synchronization), and the "next" modules depend on the recorded execution
    order and the current training state.
    """
    training_state = _get_training_state(current_handle)
    valid_training_states = (
        HandleTrainingState.BACKWARD_PRE,
        HandleTrainingState.BACKWARD_POST,
        HandleTrainingState.FORWARD,
    )
    _p_assert(
        training_state in valid_training_states,
        f"Prefetching is only supported in {valid_training_states} but "
        f"currently in {training_state}",
    )
    eod = state._exec_order_data
    target_handle: Optional[FlatParamHandle] = None
    if (
        training_state == HandleTrainingState.BACKWARD_PRE
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
    ) or (
        training_state == HandleTrainingState.BACKWARD_POST
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
    ):
        target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle)
        if (
            target_handle_candidate
            and target_handle_candidate._needs_pre_backward_unshard
            and not target_handle_candidate._prefetched
        ):
            target_handle = target_handle_candidate
        else:
            target_handle = None
    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
        target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle)
        if (
            target_handle_candidate
            and target_handle_candidate._needs_pre_forward_unshard
            and not target_handle_candidate._prefetched
        ):
            target_handle = target_handle_candidate
        else:
            target_handle = None

    return target_handle


def _get_training_state(
    handle: FlatParamHandle,
) -> HandleTrainingState:
    """Returns the training state of ``handle``."""
    _p_assert(handle, "Expects a non-empty handle")
    return handle._training_state


@no_type_check
def _register_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a pre-forward hook on ``module``.
    """
    for forward_handle in state._pre_forward_handles:
        forward_handle.remove()
    state._pre_forward_handles.clear()
    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
    hook = functools.partial(
        _pre_forward, state, module_param_handle, _pre_forward_unshard
    )
    state._pre_forward_handles.append(
        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
    )


@no_type_check
def _register_post_forward_hook(
    state: _FSDPState,
    module: nn.Module,
) -> None:
    """
    Registers a post-forward hook on ``module``. Even if the module has no
    handles, we should register the hook since it will register the module's
    pre-backward hook.
    """
    for forward_handle in state._post_forward_handles:
        forward_handle.remove()
    state._post_forward_handles.clear()
    module_param_handle = state._fully_sharded_module_to_handle.get(module, None)
    hook = functools.partial(
        _post_forward,
        state,
        module_param_handle,
        _post_forward_reshard,
    )
    state._post_forward_handles.append(module.register_forward_hook(hook))


@no_type_check
def _register_root_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
):
    """
    Registers root pre-forward hook on ``module``, which should be the local
    FSDP root.

    NOTE: For the current composable FSDP design, we have each application of
    ``fully_shard()`` to a module indicate that that module is the local
    FSDP root. We may remove this assumption in the future, in which case we
    will need to register this root pre-forward hook on any candidate module
    that may be the local FSDP root.
    """
    for forward_handle in state._root_pre_forward_handles:
        forward_handle.remove()
    state._root_pre_forward_handles.clear()
    hook = functools.partial(_root_pre_forward, state)
    state._root_pre_forward_handles.append(
        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
    )


@no_type_check
def _register_pre_backward_hooks(
    state: _FSDPState,
    module: nn.Module,
    outputs: Any,
    handle: FlatParamHandle,
) -> None:
    """
    Registers pre-backward hooks on the tensors that require gradients in the
    forward pass outputs ``outputs``, which were computed using the
    ``FlatParameter`` s of ``handles``.

    Args:
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).

    Returns:
        Forward pass outputs with pre-backward hooks registered to tensors that
        require gradients.
    """
    # If there is no gradient computation, then there is no need for
    # pre-backward logic
    if not torch.is_grad_enabled():
        return outputs
    if state._is_root:
        state._post_backward_callback_queued = False  # only defined on the root

    if handle:
        handle._needs_pre_backward_unshard = False
        # Since these handles' `FlatParameter`s participated in a forward, we
        # conservatively assume that they will be used in the backward
        handle._ran_pre_backward_hook = False

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        if t.requires_grad:
            t.register_hook(
                torch.utils.hooks.unserializable_hook(
                    functools.partial(_pre_backward_hook, state, module, handle)
                )
            )
            if handle:
                handle._needs_pre_backward_unshard = True
        return t

    return _apply_to_tensors(_register_hook, outputs)


def _register_post_backward_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
) -> None:
    """
    Registers post-backward hooks on the ``FlatParameter`` s'
    ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.

    The ``AccumulateGrad`` object represents the last function that finalizes
    the ``FlatParameter`` 's gradient, so it only runs after its entire
    gradient computation has finished.

    We register the post-backward hook only once in the *first* forward that a
    ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
    object being preserved through multiple forwards.

    NOTE: We follow this heuristic to prefer the *first* forward to target the
    parameter mixed precision case, where there are *separate*
    ``AccumulateGrad`` objects across the different forwards. (Without
    parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If
    we instead prefer the *last* forward, then the hook runs early.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    if not handle:
        return
    flat_param = handle.flat_param

    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
        if already_registered or not flat_param.requires_grad:
            return
        hook = functools.partial(_post_backward_hook, state, handle)
        hook_handle = flat_param.register_post_accumulate_grad_hook(hook)
        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined]
    else:
        already_registered = hasattr(flat_param, "_post_backward_hook_state")
        if already_registered or not flat_param.requires_grad:
            return
        # Get the `AccumulateGrad` object
        temp_flat_param = flat_param.expand_as(flat_param)
        _p_assert(
            temp_flat_param.grad_fn is not None,
            "The `grad_fn` is needed to access the `AccumulateGrad` and "
            "register the post-backward hook",
        )
        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]  # type: ignore[union-attr]
        assert acc_grad is not None
        hook_handle = acc_grad.register_hook(
            functools.partial(_post_backward_hook, state, handle)
        )
        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]


def _register_post_backward_reshard_only_hook(
    state: _FSDPState,
    handle: Optional[FlatParamHandle],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> None:
    """
    Registers post-backward hooks to reshard flat parameters that do not
    require gradient. We register these using multi-post-grad hooks on the
    input activations to ensure that all gradients that may depend on the
    parameters have been computed before resharding.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
    # where each flat parameter requires gradient
    inp_tensors: Optional[List[torch.Tensor]] = None
    if not handle:
        return
    flat_param = handle.flat_param

    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        already_registered = hasattr(flat_param, "_post_backward_hook_handle")
    else:
        already_registered = hasattr(flat_param, "_post_backward_hook_state")

    if already_registered or flat_param.requires_grad:
        return
    if inp_tensors is None:
        args_flat = pytree.arg_tree_leaves(*args, **kwargs)
        inp_tensors = [
            obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad
        ]
    assert inp_tensors is not None  # mypy
    hook_handle = register_multi_grad_hook(
        inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle)
    )
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        flat_param._post_backward_hook_handle = hook_handle  # type: ignore[attr-defined, assignment]
    else:
        flat_param._post_backward_hook_state = (hook_handle,)  # type: ignore[attr-defined, assignment]


@no_type_check
def _register_post_backward_final_callback(
    state: _FSDPState, module: nn.Module
) -> None:
    """
    Registers the post-backward final callback that runs at the end of the
    backward pass. This should be called from the root FSDP instance at the
    beginning of the pre-backward.
    """
    _p_assert(
        state._is_root,
        "Only the root FSDP instance should register the post-backward callback",
    )
    if state._post_backward_callback_queued:
        return
    _assert_in_training_states(state, [TrainingState.IDLE])
    # Trace does not need this callback
    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
        state._post_backward_callback_queued = True
        Variable._execution_engine.queue_callback(
            functools.partial(_post_backward_final_callback, state, module)
        )


def _wait_for_computation_stream(
    computation_stream: torch.Stream,
    unshard_stream: torch.Stream,
    pre_unshard_stream: torch.Stream,
):
    """
    Has the unshard and pre-unshard streams wait for the computation stream.
    For example, this should be called in the FSDP root's pre-forward to
    respect optimizer step computation.
    """
    # Tracing does not need to wait
    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
        return
    unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
    # Having the pre-all-gather stream wait for the current stream even if we
    # do not leverage the pre-all-gather stream is tolerable since this only
    # runs once per iteration
    pre_unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]


def _reset_flat_param_grad_info_if_needed(
    handles: List[FlatParamHandle],
):
    """
    Clears the original parameters' gradients if needed. This method's CPU
    overhead is minimal, so we may call it throughout FSDP methods, which serve
    as callsites to free the gradient memory earlier.
    """
    if not isinstance(handles, list):
        handles = [handles]
    for handle in handles:
        if handle._use_orig_params:
            handle._reset_flat_param_grad_info_if_needed()


@no_type_check
def _get_buffers_and_dtypes_for_computation(
    state: _FSDPState,
    root_module: nn.Module,
) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
    """
    Returns all buffers in the module tree rooted at ``root_module`` and a
    corresponding list of the buffer dtypes for computation. Each buffer dtype
    is either ``None`` if buffer mixed precision is not enabled or the buffer
    low precision dtype otherwise.
    """
    _p_assert(state._is_root, "Expects the root to cast buffers")
    buffers: List[torch.Tensor] = []
    buffer_dtypes: List[Optional[torch.dtype]] = []
    visited_buffers: Set[torch.Tensor] = set()
    # Traverse the FSDP states bottom-up so that we prefer the owning FSDP
    # instance's mixed precision setting for each buffer
    fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
        root_module
    )
    for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)):
        for buffer_name, buffer in fsdp_module.named_buffers():
            if buffer in visited_buffers:
                continue
            visited_buffers.add(buffer)
            if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names:
                continue
            buffers.append(buffer)
            buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
    assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
    return buffers, buffer_dtypes


@no_type_check
def _get_orig_buffer_dtypes(
    state: _FSDPState,
    buffer_names: List[str],
) -> List[torch.dtype]:
    """
    Returns the original buffer types of the given buffer names.
    """
    buffer_dtypes: List[torch.dtype] = []
    for buffer_name in buffer_names:
        _p_assert(
            buffer_name in state._buffer_name_to_orig_dtype,
            f"{buffer_name} is missing from pre-computed dict on rank "
            f"{state.rank}, which only has keys "
            f"{state._buffer_name_to_orig_dtype.keys()}",
        )
        buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name])
    return buffer_dtypes


def _cast_buffers_to_dtype_and_device(
    buffers: List[torch.Tensor],
    buffer_dtypes: List[Optional[torch.dtype]],
    device: torch.device,
) -> None:
    """
    Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
    to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
    corresponding buffer is only moved to ``device``.
    """
    _p_assert(
        buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
        f"Expects `buffers` and `buffer_dtypes` to have the same length if "
        f"`buffer_dtypes` is specified but got {len(buffers)} and "
        f"{len(buffer_dtypes)}",
    )
    for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
        if not torch.is_floating_point(buffer) or buffer_dtype is None:
            buffer.data = buffer.to(device=device)
        else:
            buffer.data = buffer.to(device=device, dtype=buffer_dtype)
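

# NOTE: Illustrative sketch only, not part of this module's API. The hooks above
# are registered by the public FSDP wrapper/composable APIs and run implicitly
# during forward/backward. The guarded block below is a minimal smoke test of
# that flow (assumes launch via `torchrun` with one CUDA GPU per rank):
# `_root_pre_forward`/`_pre_forward` unshard parameters before `forward`,
# `_post_forward` reshards and registers pre-backward hooks, and the
# post-backward hooks reduce-scatter gradients during `backward`.
if __name__ == "__main__":
    import os

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = FSDP(nn.Linear(8, 8).cuda())
    out = model(torch.randn(4, 8, device="cuda"))  # pre/post-forward hooks fire here
    out.sum().backward()  # pre/post-backward hooks and the final callback fire here
    dist.destroy_process_group()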