# mypy: allow-untyped-defs
import collections
import itertools
import os
import warnings
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._exec_order_utils as exec_order_utils
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _FSDPState,
    _get_module_fsdp_state,
    _is_fsdp_flattened,
    _named_parameters_with_duplicates,
    clean_tensor_name,
    TrainingState,
)
from torch.distributed.fsdp._flat_param import (
    _FSDP_USE_FULL_PREC_IN_EVAL,
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import _Policy
from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
from torch.distributed.utils import _sync_params_and_buffers
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init, fake  # type: ignore[import]
except ImportError:
    _TORCHDISTX_AVAIL = False

PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
FSDP_SYNCED = "_fsdp_synced"
# Specification of process groups for hybrid sharding strategies.
HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
# Overall specification of process group.
ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]


# TODO (awgu): Refactor this later
SHARDING_STRATEGY_MAP = {
    ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
    ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
    ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
}
HYBRID_SHARDING_STRATEGIES = [
    ShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
]
NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
    ShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
)


# NOTE: Since non-self attributes cannot be type annotated, several attributes
# on `state` are defined first as local variables before being assigned.


@no_type_check
def _init_process_group_state(
    state: _FSDPState,
    process_group: ProcessGroupType,
    sharding_strategy: ShardingStrategy,
    policy: Optional[_Policy],
    device_mesh: Optional[DeviceMesh] = None,
) -> _FSDPState:
    if process_group is not None and device_mesh is not None:
        raise ValueError(
            "Cannot pass both process_group and device_mesh at the "
            "same time. Please just pass only one of them."
        )
    is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES
    if is_hybrid_strategy:
        if process_group is None and policy is None and device_mesh is None:
            # Raise an error here since this is manual wrapping with no process
            # group passed in; there is no way to ensure that all wrapped FSDP
            # instances use the same process groups.
            raise ValueError(
                f"Manual wrapping with {sharding_strategy} "
                "requires explicit specification of process group or device_mesh."
            )
        else:
            state = _init_process_group_state_for_hybrid_shard(
                state, process_group, device_mesh
            )
    else:
        if device_mesh:
            state._device_mesh = device_mesh
            state.process_group = device_mesh.get_group(mesh_dim=0)
        else:
            state.process_group = (
                process_group if process_group is not None else _get_default_group()
            )

    state.rank = state.process_group.rank()
    state.world_size = state.process_group.size()
    data_parallel_world_size = state.world_size
    if is_hybrid_strategy:
        data_parallel_world_size *= state._inter_node_pg.size()
    state._gradient_predivide_factor = (
        default_hooks.DefaultState._get_gradient_predivide_factor(
            data_parallel_world_size
        )
    )
    state._gradient_postdivide_factor = (
        data_parallel_world_size / state._gradient_predivide_factor
    )
    return state


@no_type_check
def _init_process_group_state_for_hybrid_shard(
    state: _FSDPState,
    process_group: ProcessGroupType,
    device_mesh: DeviceMesh,
) -> _FSDPState:
    if device_mesh:
        if _is_valid_hybrid_shard_device_mesh(device_mesh):
            state._device_mesh = device_mesh
            # We currently only allow `_inter_node_pg` to be the outermost
            # dimension and the (intra-node) `process_group` to be the
            # innermost dimension.
            state._inter_node_pg = device_mesh.get_group(mesh_dim=0)
            state.process_group = device_mesh.get_group(mesh_dim=1)
        else:
            raise ValueError(
                f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}"
            )
    elif process_group is None:
        default_group = _get_default_group()
        intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
            default_group, state._device_handle.device_count()
        )
        # We shard across the intra-node group.
        state.process_group = intra_node_group
        # Save the inter-node group to all-reduce across.
        state._inter_node_pg = inter_node_group
    else:
        # Check the type and assign state.process_group and state._inter_node_pg.
        if _is_valid_hybrid_shard_pg_type(process_group):
            # Assume that the user passed in the intra-node group and the
            # inter-node group, as documented.
            state.process_group, state._inter_node_pg = process_group
        else:
            raise ValueError(
                "Expected process_group to be passed in as either None or "
                f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
            )
    # Create state for allreduce
    state._inter_node_state = _get_default_comm_hook_state(
        process_group=state._inter_node_pg,
    )
    return state


@no_type_check
def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
    return (
        isinstance(process_group, tuple)
        and len(process_group) == 2
        and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
    )


@no_type_check
def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
    return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2


@no_type_check
def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup:
    """
    Return a process group across the current node.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return an intra-node subgroup across
    [0, 1, ..., 7] or [8, 9, ..., 15] depending on the process's rank.
    For example, rank 3 would get [0, 1, ..., 7].
    """
    intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node)
    return intra_node_subgroup


@no_type_check
def _init_inter_node_process_group(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> dist.ProcessGroup:
    """
    Return an inter-node process group where each contained rank has the same local rank.

    For example, given each row is a distinct node:
    0 1 2 3 4 5 6 7
    8 9 10 11 12 13 14 15
    This API would return inter-node process group [0, 8], [1, 9], [2, 10], and so forth
    depending on the process's rank. For example, rank 1 would get [1, 9], rank 5
    would get [5, 13].
    """
    # the inter-node pg that is returned
    inter_node_pg = None
    sharding_backend = dist.get_backend(global_process_group)
    world_size = dist.get_world_size(global_process_group)
    # Assuming fully homogeneous setup
    num_nodes = world_size // num_devices_per_node
    my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node
    for local_rank in range(num_devices_per_node):
        ranks_for_inter_group = [
            local_rank + (i * num_devices_per_node) for i in range(num_nodes)
        ]
        # every rank always needs to call dist.new_group
        grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
        if local_rank == my_local_rank:
            inter_node_pg = grp

    assert (
        inter_node_pg is not None
    ), f"{my_local_rank} expected to assign inter-node pg, but did not"
    return inter_node_pg


def _init_intra_and_inter_node_groups(
    global_process_group: dist.ProcessGroup,
    num_devices_per_node: int,
) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
    """
    Initialize intra and inter-node process groups and return the ones corresponding to this process's rank.

    This function can be used to initialize process groups for ``HYBRID_SHARD`` or
    ``_HYBRID_SHARD_ZERO2`` in FSDP.
    This function assumes each node has an equal number of CUDA-enabled devices.
    Returns:
        Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group.
278 """ 279 return ( 280 _init_intra_node_process_group(num_devices_per_node), 281 _init_inter_node_process_group(global_process_group, num_devices_per_node), 282 ) 283 284 285@no_type_check 286def _init_ignored_module_states( 287 state: _FSDPState, 288 module: nn.Module, 289 ignored_modules: Optional[Iterable[torch.nn.Module]], 290 ignored_states: Union[ 291 Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] 292 ] = None, 293) -> _FSDPState: 294 if ignored_modules is not None and ignored_states is not None: 295 raise ValueError( 296 "Cannot pass both ignored_modules and ignored_states at the " 297 "same time. Please just pass ignored_states." 298 ) 299 ignored_parameters = None 300 passed_as_ignored_states = ignored_states is not None 301 if passed_as_ignored_states: 302 ignored_states_list = list(ignored_states) 303 _check_ignored_states(ignored_states_list, True) 304 else: 305 ignored_states_list = [] 306 _check_ignored_states( 307 list(ignored_modules) if ignored_modules is not None else [], False 308 ) 309 if len(ignored_states_list) > 0: 310 if isinstance(ignored_states_list[0], nn.Parameter): 311 ignored_parameters = ignored_states_list 312 else: 313 ignored_modules = ignored_states_list 314 state._ignored_modules = _get_ignored_modules(module, ignored_modules) 315 state._ignored_params = _get_ignored_params( 316 module, 317 state._ignored_modules, 318 ignored_parameters, 319 ) 320 state._ignored_buffer_names = _get_ignored_buffer_names( 321 module, 322 state._ignored_modules, 323 ) 324 # TODO: FSDP's contract for buffers is not well-defined. They are 325 # implicitly ignored for most functionality since they are not sharded; 326 # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed 327 # precision). We should formalize this contract and decide if we need to 328 # compute and store `_ignored_buffers`. 329 return state 330 331 332def _check_ignored_states( 333 ignored_states: List[Any], passed_as_ignored_states: bool 334) -> None: 335 """ 336 Check that the ignored states are uniformly parameters or uniformly modules. 337 338 We may remove this check in the future if we permit mixing. 339 """ 340 if len(ignored_states) == 0: 341 return 342 if passed_as_ignored_states: 343 all_params = all(isinstance(state, nn.Parameter) for state in ignored_states) 344 all_modules = all(isinstance(state, nn.Module) for state in ignored_states) 345 if not all_params and not all_modules: 346 # Sort for consistent ordering for unit test regex matching 347 sorted_types = sorted({type(state) for state in ignored_states}, key=repr) 348 raise ValueError( 349 "ignored_states expects all nn.Parameter or all nn.Module list " 350 f"elements but got types {sorted_types}" 351 ) 352 else: 353 if not all(isinstance(state, nn.Module) for state in ignored_states): 354 sorted_types = sorted({type(state) for state in ignored_states}, key=repr) 355 raise ValueError( 356 "ignored_modules expects nn.Module list elements but got " 357 f"types {sorted_types}" 358 ) 359 360 361@no_type_check 362def _init_device_handle( 363 state: _FSDPState, 364 module: nn.Module, 365 ignored_params: Set[nn.Parameter], 366 device_id: Optional[Union[int, torch.device]], 367) -> _FSDPState: 368 """ 369 Determine device handle used for initializing FSDP. 370 371 If a device is specified by ``device_id``, 372 then returns device handle corresponds to that device type. Otherwise, If the 373 module is already on a non-CPU device, then the device type is that non-CPU device type. 
    If the module is on CPU or meta, then the device type is the current accelerator device.
    See :ref:`Accelerators<accelerators>` for details.

    This method will be called once the ignored parameters have been determined,
    as the device handle may be needed for other initialization.
    """
    determined_device = None
    if device_id is not None:
        determined_device = (
            device_id
            if isinstance(device_id, torch.device)
            else torch.device(device_id)
        )
    if determined_device is None:
        for param in _get_orig_params(module, ignored_params):
            if param.device.type in {"cpu", "meta"}:
                continue
            if determined_device is None:
                determined_device = param.device
            else:
                if param.device.type != determined_device.type:
                    raise RuntimeError(
                        f"FSDP does not support modules with different device types "
                        f"but got params on {determined_device.type} and {param.device.type}"
                    )
        determined_device = determined_device or torch._C._get_accelerator()
        if determined_device.type == "cpu":
            raise RuntimeError(
                "FSDP needs a non-CPU accelerator device, but no accelerator device is detected."
            )

    state._device_handle = _FSDPDeviceHandle.from_device(determined_device)
    return state


@no_type_check
def _init_buffer_state(
    state: _FSDPState,
    module: nn.Module,
) -> _FSDPState:
    state._buffer_names = _get_buffer_names(module)
    # Save a mapping from clean fully-qualified buffer name (starting from
    # `module`) to its original dtype for restoring that dtype during model
    # checkpointing when buffer mixed precision is enabled. The names should
    # be clean since the casting happens in a `summon_full_params()` context.
    _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
    for buffer_name, buffer in module.named_buffers():
        buffer_name = clean_tensor_name(buffer_name)
        _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
    state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
    return state


@no_type_check
def _init_core_state(
    state: _FSDPState,
    sharding_strategy: Optional[ShardingStrategy],
    mixed_precision: Optional[MixedPrecision],
    cpu_offload: Optional[CPUOffload],
    limit_all_gathers: bool,
    use_orig_params: bool,
    backward_prefetch_limit: int,
    forward_prefetch_limit: int,
) -> _FSDPState:
    # We clamp the strategy to `NO_SHARD` for world size of 1 since they are
    # currently functionally equivalent. This may change if/when we integrate
    # FSDP with MoE.
    if state.world_size == 1:
        if sharding_strategy != ShardingStrategy.NO_SHARD:
            warnings.warn(
                "FSDP is switching to use `NO_SHARD` instead of "
                f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
                "the world size is 1."
            )
        sharding_strategy = ShardingStrategy.NO_SHARD
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        warnings.warn(
            "The `NO_SHARD` sharding strategy is deprecated. If you are having issues, "
            "please use `DistributedDataParallel` instead.",
            FutureWarning,
            # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and
            # level 3 is from the true caller
            stacklevel=3,
        )
    state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
    state.mixed_precision = mixed_precision or MixedPrecision()
    if mixed_precision is not None:
        torch._C._log_api_usage_once(
            f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}"
        )
    state._use_full_prec_in_eval = (
        os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
    )
    state.cpu_offload = cpu_offload or CPUOffload()
    state.limit_all_gathers = limit_all_gathers
    state._use_orig_params = use_orig_params
    state.training_state = TrainingState.IDLE
    state._is_root = None
    state._free_event_queue = _FreeEventQueue()
    state._debug_level = dist.get_debug_level()
    state._exec_order_data = exec_order_utils._ExecOrderData(
        state._debug_level,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    state._unshard_event = None
    # Mapping from fully sharded module to the handles it is responsible to
    # unshard and reshard (see [Note: Fully Sharded Module])
    _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = {}
    state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle
    # Invariant: `state.params` contains exactly the `FlatParameter`s of the
    # handles in `state._handle`
    _handle: Optional[FlatParamHandle] = None
    state._handle = _handle
    params: List[FlatParameter] = []
    state.params = params
    return state


@no_type_check
def _init_runtime_state(
    state: _FSDPState,
) -> _FSDPState:
    _root_pre_forward_handles: List[RemovableHandle] = []
    state._root_pre_forward_handles = _root_pre_forward_handles
    _pre_forward_handles: List[RemovableHandle] = []
    state._pre_forward_handles = _pre_forward_handles
    _post_forward_handles: List[RemovableHandle] = []
    state._post_forward_handles = _post_forward_handles
    state._sync_gradients = True
    state._comm_hook = None
    state._comm_hook_state = None
    # Used to prevent running the pre-backward hook multiple times
    return state


@no_type_check
def _init_prefetching_state(
    state: _FSDPState,
    backward_prefetch: BackwardPrefetch,
    forward_prefetch: bool,
) -> _FSDPState:
    state.backward_prefetch = backward_prefetch
    state.forward_prefetch = forward_prefetch
    # The data structures use tuples of handles to generalize over the case
    # where a module's forward involves multiple handles.
    return state


@no_type_check
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
    # TODO: We need to add an additional check once we support FSDP + PiPPy.
    # This check is currently sufficient, since we only support FSDP + TP.
    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
    # If the root mesh is not the same as device_mesh, then device_mesh was
    # sliced out from the root mesh.
    if device_mesh and root_mesh != state._device_mesh:
        state._fsdp_extension = DTensorExtensions(state._device_handle)
    else:
        # We need to explicitly set _fsdp_extension to None.
        # Otherwise, we will run into an infinite recursion when getting the attribute.
        state._fsdp_extension = None
    return state


@no_type_check
def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
    state._state_dict_type = StateDictType.FULL_STATE_DICT
    state_dict_config: StateDictConfig = FullStateDictConfig()
    state._optim_state_dict_config = FullOptimStateDictConfig()
    state._state_dict_config = state_dict_config
    unshard_params_ctx: Dict[nn.Module, Generator] = {}
    state._unshard_params_ctx = unshard_params_ctx

    return state


def _verify_managed_params(module: nn.Module, params: List[nn.Parameter]) -> None:
    """
    Verify that the parameters are accepted by FSDP. The only restriction now
    is that a parameter cannot be a scalar tensor (param.shape == []).
    """
    for param in params:
        if len(param.shape) == 0:
            param_name = ""
            for name, param_ in module.named_parameters():
                if param is param_:
                    param_name = name
                    break
            assert param_name
            raise ValueError(
                "FSDP doesn't support scalar parameters. "
                f"Change {param_name} to a 1D tensor with numel equal to 1."
            )


@no_type_check
def _init_param_handle_from_module(
    state: _FSDPState,
    fully_sharded_module: nn.Module,
    device_id: Optional[Union[int, torch.device]],
    param_init_fn: Optional[Callable[[nn.Module], None]],
    sync_module_states: bool,
) -> _FSDPState:
    """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``."""
    _check_single_device_module(fully_sharded_module, state._ignored_params, device_id)
    device_from_device_id = _get_device_from_device_id(
        device_id, state.rank, state._device_handle
    )
    is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
        fully_sharded_module, state._ignored_params, state._ignored_modules
    )
    # Materialize the module if needed
    if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
        _materialize_with_param_init_fn(
            fully_sharded_module, param_init_fn, state._ignored_modules
        )
    elif is_meta_module:
        _materialize_meta_module(
            fully_sharded_module,
            device_id,
            state._ignored_modules,
            state._device_handle,
        )
    elif is_torchdistX_deferred_init:
        deferred_init.materialize_module(
            fully_sharded_module,
            check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None
            and submodule not in state._ignored_modules,
        )

    ignored_buffers = {
        buffer
        for ignored_module in state._ignored_modules
        for buffer in ignored_module.buffers()
    }

    _move_module_to_device(
        fully_sharded_module,
        state._ignored_params,
        ignored_buffers,
        device_from_device_id,
    )
    state.compute_device = _get_compute_device(
        fully_sharded_module,
        state._ignored_params,
        device_from_device_id,
        state.rank,
        state._device_handle,
    )

    managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
    _verify_managed_params(fully_sharded_module, managed_params)
    if sync_module_states:
        _sync_module_params_and_buffers(
            fully_sharded_module, managed_params, state.process_group
        )
        if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            _sync_module_params_and_buffers(
                fully_sharded_module, managed_params, state._inter_node_pg
            )
    _init_param_handle_from_params(state, managed_params, fully_sharded_module)
    return state


@no_type_check
def _init_param_handle_from_params(
    state: _FSDPState,
    params: List[nn.Parameter],
    fully_sharded_module: nn.Module,
):
    if len(params) == 0:
        return
    handle = FlatParamHandle(
        params,
        fully_sharded_module,
        state.compute_device,
        SHARDING_STRATEGY_MAP[state.sharding_strategy],
        state.cpu_offload.offload_params,
        state.mixed_precision.param_dtype,
        state.mixed_precision.reduce_dtype,
        state.mixed_precision.keep_low_precision_grads,
        state.process_group,
        state._use_orig_params,
        fsdp_extension=state._fsdp_extension,
    )
    handle.shard()
    assert not state._handle
    state.params.append(handle.flat_param)
    state._handle = handle
    state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
    cpu_device = torch.device("cpu")
    if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
        handle.flat_param_to(cpu_device)


def _get_ignored_modules(
    root_module: nn.Module,
    _ignored_modules: Optional[Iterable[torch.nn.Module]],
) -> Set[nn.Module]:
    """
    Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances.

    Return the modules contained in their module
    subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
    already-computed ignored modules are included.

    ``_ignored_modules`` represents the argument passed by the user to FSDP.
    """
    msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
    try:
        ignored_root_modules = (
            set(_ignored_modules) if _ignored_modules is not None else set()
        )
    except TypeError as e:
        raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
    for module in ignored_root_modules:
        if not isinstance(module, torch.nn.Module):
            raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
        if _get_module_fsdp_state(module):
            # TODO: We may relax this by taking the FSDP instance's wrapped
            # module to provide more flexibility to the user.
            raise ValueError("`ignored_modules` should not include FSDP modules")
    # Treat modules that cannot compose with `fully_shard` as ignored modules,
    # meaning that their subtrees are ignored
    for module in root_module.modules():
        if not traversal_utils._composable(module):
            ignored_root_modules.add(module)
    # NOTE: Even if `ignored_root_modules` is empty, do not return early so
    # that this FSDP instance can get any ignored modules from its children.

    # Include child modules and exclude nested FSDP modules themselves
    ignored_modules = {
        child
        for module in ignored_root_modules
        for child in module.modules()
        if not isinstance(child, fsdp_file.FullyShardedDataParallel)
    }
    if root_module in ignored_modules:
        warnings.warn(
            "Trying to ignore the top-level module passed into the FSDP "
            "constructor itself will result in all parameters being "
            f"ignored and is not well-supported: {root_module}"
        )
    # Include nested FSDP modules' ignored modules
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_modules")
            ignored_modules.update(optional_fsdp_state._ignored_modules)
    return ignored_modules


def _get_ignored_params(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
    ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
) -> Set[torch.nn.Parameter]:
    """
    Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``.

    :class:`FlatParameter` s are excluded from the result.
    """
    all_ignored_params: Set[torch.nn.Parameter] = set()

    params_in_ignored_modules = {
        p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
    }

    all_ignored_params.update(params_in_ignored_modules)

    if ignored_parameters is not None:
        params_in_ignored_parameters = {
            p for p in ignored_parameters if not _is_fsdp_flattened(p)
        }
        all_ignored_params.update(params_in_ignored_parameters)

    # Always include nested FSDP modules' ignored parameters
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_params")
            all_ignored_params.update(optional_fsdp_state._ignored_params)

    return all_ignored_params


def _get_ignored_buffer_names(
    root_module: torch.nn.Module,
    ignored_modules: Set[torch.nn.Module],
) -> Set[str]:
    """Return the cleaned buffer FQNs in ``ignored_modules``."""
    all_ignored_buffer_names: Set[str] = set()

    buffers_in_ignored_modules = {
        buffer for m in ignored_modules for buffer in m.buffers()
    }

    all_ignored_buffer_names.update(
        {
            clean_tensor_name(buffer_name)
            for buffer_name, buffer in root_module.named_buffers()
            if buffer in buffers_in_ignored_modules
        }
    )

    # Always include nested FSDP modules' ignored buffer names
    for submodule in root_module.modules():
        optional_fsdp_state = _get_module_fsdp_state(submodule)
        if optional_fsdp_state is not None:
            assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
            all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)

    return all_ignored_buffer_names


def _get_buffer_names(root_module: nn.Module) -> Set[str]:
    """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a :class:`set`."""
    return {
        clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
    }


def _check_single_device_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_id: Optional[Union[int, torch.device]],
) -> None:
    """
    Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``.

    Thus, after this method, the
    module must be either fully on the CPU or fully on a non-CPU device.
    """
    devices = {param.device for param in _get_orig_params(module, ignored_params)}
    # We allow module to be partially on CPU and partially on GPU if device_id is not
    # None, since the device_id arg will result in the CPU portion being moved to
    # GPU. This is useful in cases where part of the module may be parallelized
    # by another algorithm and may already be on GPU. We'd like to enforce device_id
    # to not be None, otherwise we'd flatten parameters in a mixed module which is
    # not supported.
    if len(devices) == 2 and torch.device("cpu") in devices:
        if device_id is None:
            raise RuntimeError(
                "To support a module with both CPU and GPU params, "
                "please pass in the device_id argument."
            )
    elif len(devices) > 1:
        raise RuntimeError(
            f"FSDP only supports single device modules but got params on {devices}"
        )


def _get_device_from_device_id(
    device_id: Optional[Union[int, torch.device]],
    rank: int,
    device_handle: _FSDPDeviceHandle,
) -> Optional[torch.device]:
    """
    Return a ``torch.device`` for the specified ``device_id``.

    Processes ``device_id`` and returns either the corresponding device or
    ``None`` if ``device_id`` is ``None``.
    """
    if device_id is None:
        return None
    device = (
        device_id if isinstance(device_id, torch.device) else torch.device(device_id)
    )
    if device.type != "cpu" and device.index is None:
        warnings.warn(
            f"FSDP got the argument `device_id` {device_id} on rank "
            f"{rank}, which does not have an explicit index. "
            f"FSDP will use the current device {device_handle.current_device()}. "
            f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` "
            "before FSDP initialization or pass in the explicit device "
            "index as the `device_id` argument."
        )
        device = torch.device(device_handle.current_device())
    return device


def _need_to_materialize_module(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_modules: Set[nn.Module],
) -> Tuple[bool, bool]:
    """
    Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization.

    At most one of the returned bools can
    be ``True``. If either is ``True``, then ``module`` needs to be
    materialized.
    """
    managed_params = list(_get_orig_params(module, ignored_params))
    is_meta_module = any(param.is_meta for param in managed_params)
    # TODO: We need to establish a contract for FSDP and buffers. For now, we
    # skip checking for meta buffers from ignored modules. We should consider
    # refactoring the initialization holistically to avoid so many traversals.
    for submodule in module.modules():
        if submodule in ignored_modules:
            continue
        for buf in submodule.buffers(recurse=False):
            is_meta_module |= buf.is_meta
    is_torchdistX_deferred_init = (
        not is_meta_module
        and _TORCHDISTX_AVAIL
        and any(fake.is_fake(param) for param in managed_params)
    )
    return is_meta_module, is_torchdistX_deferred_init


def _materialize_with_param_init_fn(
    root_module: nn.Module,
    param_init_fn: Callable[[nn.Module], None],
    ignored_modules: Set[nn.Module],
) -> None:
    if not callable(param_init_fn):
        raise ValueError(
            f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
        )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    for module in modules_to_materialize:
        param_init_fn(module)


def _materialize_meta_module(
    root_module: nn.Module,
    device_from_device_id: Optional[torch.device],
    ignored_modules: Set[nn.Module],
    device_handle: _FSDPDeviceHandle,
):
    # Run default meta device initialization
    materialization_device = device_from_device_id or torch.device(
        device_handle.current_device()
    )
    modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules)
    module = None
    try:
        # Assume that each module's `reset_parameters()` only initializes its
        # own parameters and not those of its children
        with torch.no_grad():
            for module in modules_to_materialize:
                # As a contract to the user, only call `reset_parameters()` if
                # the module has directly managed parameters/buffers
                module_state_iter = itertools.chain(
                    module.parameters(recurse=False), module.buffers(recurse=False)
                )
                has_module_states = len(list(module_state_iter)) > 0
                if has_module_states:
                    module.to_empty(device=materialization_device, recurse=False)
                    module.reset_parameters()  # type: ignore[operator]
    except BaseException as e:
        warnings.warn(
            "Unable to call `reset_parameters()` for module on meta "
            f"device with error {str(e)}. Please ensure that your module of "
            f"type {type(module)} implements a `reset_parameters()` method."  # type: ignore[possibly-undefined]
        )
        raise e


def _get_modules_to_materialize(
    root_module: nn.Module, ignored_modules: Set[nn.Module]
) -> List[nn.Module]:
    # Run BFS to collect the modules to materialize via `reset_parameters()`,
    # stopping at any module with FSDP already applied or at ignored modules.
    modules_to_materialize: List[nn.Module] = []
    queue = collections.deque([root_module])
    visited_modules: Set[nn.Module] = {root_module}
    while queue:
        module = queue.popleft()
        modules_to_materialize.append(module)
        for child_module in module.children():
            if (
                child_module not in visited_modules
                and _get_module_fsdp_state(child_module) is None
                and child_module not in ignored_modules
            ):
                visited_modules.add(child_module)
                queue.append(child_module)
    return modules_to_materialize


def _move_module_to_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    ignored_buffers: Set[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move ``module`` depending on ``device_from_device_id`` and its current device.

    This includes moving ignored modules' parameters.

    - If ``device_from_device_id`` is not ``None``, then this moves
      ``module`` to the device.
    - If ``device_from_device_id`` is ``None``, then this does not move
      ``module`` but warns the user if it is on CPU.

    Precondition: ``_check_single_device_module()``.
    """
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # BFS from `module` without traversing any nested FSDP instances to
        # collect the parameters/buffers that have not yet been managed
        queue: Deque[nn.Module] = collections.deque()
        queue.append(module)
        params: List[nn.Parameter] = []
        buffers: List[torch.Tensor] = []
        while queue:
            curr_module = queue.popleft()
            # NOTE: We include a check to only move parameters/buffers that are
            # on CPU device. If they are on a CUDA device different from the
            # one specified by `device_id`, then this does NOT move them. This
            # is so that we can raise an error in `_get_compute_device()`.
            params.extend(
                param
                for param in curr_module.parameters(recurse=False)
                if param.device == cpu_device
            )
            buffers.extend(
                buffer
                for buffer in curr_module.buffers(recurse=False)
                if buffer.device == cpu_device
            )
            for submodule in curr_module.children():
                if not isinstance(submodule, fsdp_file.FullyShardedDataParallel):
                    queue.append(submodule)
        params_to_move = [p for p in params if p not in ignored_params]
        bufs_to_move = [p for p in buffers if p not in ignored_buffers]
        _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
        return
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device == cpu_device:
        _warn_cpu_init()


def _move_states_to_device(
    params: List[nn.Parameter],
    buffers: List[torch.Tensor],
    device_from_device_id: Optional[torch.device],
) -> None:
    """
    Move states to the specified device.

    Precondition: ``_check_single_device_module()`` and module's parameters and
    buffers have been materialized if needed.
    """
    if len(params) == 0 and len(buffers) == 0:
        return
    if len(params) > 0:
        current_device = params[0].device
    elif len(buffers) > 0:
        current_device = buffers[0].device
    cpu_device = torch.device("cpu")
    if device_from_device_id is not None:
        # Move the parameters and buffers like the `.data` code path in
        # `nn.Module._apply()`, which underlies `nn.Module.to()`
        for param in params:
            with torch.no_grad():
                param.data = param.to(device_from_device_id)
                if param.grad is not None:
                    param.grad.data = param.grad.to(device_from_device_id)
        for buffer in buffers:
            buffer.data = buffer.to(device_from_device_id)
    elif current_device == cpu_device:  # type: ignore[possibly-undefined]
        _warn_cpu_init()


def _warn_cpu_init():
    warnings.warn(
        "The passed-in `module` is on CPU and will thus have FSDP's sharding "
        "initialization run on CPU, which may be slower than on GPU. We "
        "recommend passing in the `device_id` argument for FSDP to move "
        "`module` to GPU for the sharding initialization. `module` must also "
        "be on GPU device to work with the `sync_module_states=True` flag "
        "since that requires GPU communication."
    )


def _get_compute_device(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
    device_from_device_id: Optional[torch.device],
    rank: int,
    device_handle: _FSDPDeviceHandle,
) -> torch.device:
    """
    Determine and return this FSDP instance's compute device.

    If the module is already on a non-CPU device, then the compute device is that non-CPU
    device. If the module is on CPU, then the compute device is the current
    device.

    Since this method should be called after materializing the module, any
    non-CPU device should not be meta device. For now, the compute device is
    always a CUDA or CUDA-like device with its explicit index.

    Precondition: ``_check_single_device_module()`` and
    ``_move_module_to_device()``.
    """
    param = next(_get_orig_params(module, ignored_params), None)
    if param is not None and param.device.type != "cpu":
        compute_device = param.device  # Determined by model param placement
    else:
        compute_device = torch.device(device_handle.current_device())
    if device_from_device_id is not None and compute_device != device_from_device_id:
        raise ValueError(
            f"Inconsistent compute device and `device_id` on rank {rank}: "
            f"{compute_device} vs {device_from_device_id}"
        )
    return compute_device


# TODO: See how to deprecate!
def _sync_module_params_and_buffers(
    module: nn.Module,
    params: List[nn.Parameter],
    process_group: dist.ProcessGroup,
) -> None:
    """
    Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks.

    Precondition: ``sync_module_states == True`` and ``self.process_group`` has
    been set.
    """
    module_states: List[torch.Tensor] = []
    for buffer in module.buffers():
        # Avoid re-synchronizing buffers in case of nested wrapping
        if not getattr(buffer, FSDP_SYNCED, False):
            setattr(buffer, FSDP_SYNCED, True)
            detached_buffer = buffer.detach()
            if is_traceable_wrapper_subclass(detached_buffer):
                # NOTE: Here we assume no nested subclasses, at most one level of subclass
                # in both model's buffers and params
                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
                module_states.extend(inner_buffers)
            else:
                module_states.append(detached_buffer)

    for param in params:
        detached_param = param.detach()
        if is_traceable_wrapper_subclass(detached_param):
            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
            inner_params = [getattr(detached_param, attr) for attr in attrs]
            module_states.extend(inner_params)
        else:
            module_states.append(detached_param)

    _check_module_states_for_sync_module_states(module_states)
    _sync_params_and_buffers(
        process_group,
        module_states,
        PARAM_BROADCAST_BUCKET_SIZE,
        src=0,
    )


def _check_module_states_for_sync_module_states(
    module_states: List[torch.Tensor],
) -> None:
    if module_states and any(
        tensor.device == torch.device("cpu") for tensor in module_states
    ):
        raise ValueError(
            "The module has CPU parameters or buffers when `sync_module_states=True`, "
            "which requires them to be on GPU. Please specify the `device_id` argument "
            "or move the module to GPU before passing it to FSDP."
        )


def _get_orig_params(
    module: nn.Module,
    ignored_params: Set[nn.Parameter],
) -> Iterator[nn.Parameter]:
    """
    Return an iterator over the original parameters in ``module``.

    The iterator does not return
    the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be
    present due to nested FSDP wrapping), or any original parameters already
    flattened (only relevant when ``use_orig_params=True``).
    """
    param_gen = module.parameters()
    try:
        while True:
            param = next(param_gen)
            if param not in ignored_params and not _is_fsdp_flattened(param):
                yield param
    except StopIteration:
        pass


def _check_orig_params_flattened(
    fsdp_module,
    ignored_params: Set[nn.Parameter],
) -> None:
    """
    Check that original parameters in ``fsdp_module`` have been flattened.

    The flattened parameters are made
    invisible to ``named_parameters()`` for the module hierarchy rooted at
    ``fsdp_module``. This should be called as a sanity check after flattening
    the wrapped module's parameters.
    """
    for param_name, param in _named_parameters_with_duplicates(fsdp_module):
        if param not in ignored_params and not _is_fsdp_flattened(param):
            raise RuntimeError(
                f"Found an unflattened parameter: {param_name}; "
                f"{param.size()} {param.__class__}"
            )


def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
    return (
        default_hooks.allreduce_hook
        if sharding_strategy == ShardingStrategy.NO_SHARD
        else default_hooks.reduce_scatter_hook
    )


def _get_default_comm_hook_state(
    process_group: dist.ProcessGroup,
) -> default_hooks.DefaultState:
    return default_hooks.DefaultState(process_group=process_group)
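

# -----------------------------------------------------------------------------
# Illustrative sketch only (not part of the FSDP API): the guarded block below
# mirrors the rank arithmetic documented in `_init_intra_node_process_group`
# and `_init_inter_node_process_group` for a hypothetical, fully homogeneous
# cluster, so the HYBRID_SHARD group layout can be inspected without creating
# any process groups. The node and per-node device counts are assumed example
# values.
if __name__ == "__main__":
    num_nodes = 2  # assumed number of nodes
    num_devices_per_node = 8  # assumed accelerators per node
    # Intra-node (sharding) groups: consecutive ranks on the same node.
    intra_node_groups = [
        list(range(node * num_devices_per_node, (node + 1) * num_devices_per_node))
        for node in range(num_nodes)
    ]
    # Inter-node (gradient all-reduce) groups: ranks sharing the same local rank.
    inter_node_groups = [
        [local_rank + node * num_devices_per_node for node in range(num_nodes)]
        for local_rank in range(num_devices_per_node)
    ]
    print("Intra-node (sharding) groups:", intra_node_groups)
    print("Inter-node (all-reduce) groups:", inter_node_groups)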