1# mypy: allow-untyped-defs 2""" 3This file includes private common utilities for FSDP. 4""" 5import logging 6import traceback 7import warnings 8import weakref 9from enum import auto, Enum 10from functools import partial 11from typing import ( 12 Any, 13 Callable, 14 cast, 15 Dict, 16 Generator, 17 Iterable, 18 List, 19 no_type_check, 20 Optional, 21 Set, 22 Tuple, 23 Type, 24 TYPE_CHECKING, 25) 26 27import torch 28import torch.distributed as dist 29import torch.distributed.fsdp._flat_param as flat_param_file 30import torch.nn as nn 31from torch.distributed._composable_state import _get_module_state, _State 32from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 33 _CHECKPOINT_PREFIX, 34) 35from torch.distributed.utils import _apply_to_tensors 36from torch.utils._mode_utils import no_dispatch 37 38from .api import ( 39 FullOptimStateDictConfig, 40 FullStateDictConfig, 41 OptimStateDictConfig, 42 ShardingStrategy, 43 StateDictConfig, 44 StateDictType, 45) 46 47 48if TYPE_CHECKING: 49 from torch.distributed.device_mesh import DeviceMesh 50 from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions 51 52 from ._flat_param import FlatParamHandle 53 54FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" 55FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." 56FSDP_FLATTENED = "_fsdp_flattened" 57 58# Save a global mapping from module to its input tensor dtype to be populated 59# during the forward pre-hook and consumed in the forward post-hook when 60# overriding a module's mixed precision 61# NOTE: We currently take the last input tensor's dtype in the case of multiple 62# floating-point input tensors, which may be incorrect. However, since there is 63# not a 1:1 correspondence between input and output tensors, we must use *some* 64# heuristic like this to predict the desired output dtype. 65_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() 66 67 68class _FSDPDeviceHandle: 69 """ 70 This is a simple abstraction for FSDP computing devices, 71 which enables custom backends that implement CUDA-like 72 semantics to be integrated with FSDP. 73 """ 74 75 def __init__(self, device: torch.device, backend: Any = None): 76 if backend is None: 77 try: 78 self.__backend = getattr(torch, device.type) 79 self.__device = device 80 except AttributeError as exc: 81 raise AttributeError( 82 f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'." 83 ) from exc 84 else: 85 self.__backend = backend 86 87 @classmethod 88 def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle": 89 """ 90 Return a device handle corresponding to the device, and through this handle, 91 operations with the same semantics as CUDA can be performed on the device. 92 Just return torch.cuda if the device is cuda to make attribute-access faster. 93 Custom backend must first register a module with the same name with {device.type} on torch. 94 """ 95 if device.type == "cuda": 96 return cast(_FSDPDeviceHandle, torch.cuda) 97 elif device.type == "mtia": 98 return cast(_FSDPDeviceHandle, torch.mtia) 99 return cls(device) 100 101 def __getattr__(self, __name: str) -> Any: 102 try: 103 return getattr(self.__backend, __name) 104 except AttributeError as exc: 105 raise AttributeError( 106 f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{__name}'" 107 ) from exc 108 109 110class _UninitializedDeviceHandle(_FSDPDeviceHandle): 111 def __init__(self) -> None: 112 pass 113 114 def __getattribute__(self, __name: str) -> Any: 115 raise RuntimeError("Trying to use an uninitialized device handle.") 116 117 118class _FSDPState(_State): 119 def __init__(self) -> None: 120 # TODO: Move all the attributes to this class to enable typing for 121 # FSDP/fully_shard. 122 self._ignored_modules: Set[nn.Module] = set() 123 self._ignored_params: Set[nn.Parameter] = set() 124 # Buffer names are cleaned (without wrapper prefixes) 125 self._ignored_buffer_names: Set[str] = set() 126 self.process_group: Optional[dist.ProcessGroup] = None 127 self.rank: int = -1 128 self.world_size: int = -1 129 self._device_mesh: Optional[DeviceMesh] = None 130 self.sharding_strategy = ShardingStrategy.FULL_SHARD 131 self._use_orig_params: bool = False 132 self.training_state = TrainingState.IDLE 133 self._unshard_params_ctx: Dict[nn.Module, Generator] = {} 134 self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT 135 self._state_dict_config: StateDictConfig = FullStateDictConfig() 136 self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig() 137 self._is_root: Optional[bool] = None 138 self._handle: Optional[flat_param_file.FlatParamHandle] = None 139 self._fully_sharded_module_to_handle: Dict[ 140 nn.Module, Optional[flat_param_file.FlatParamHandle] 141 ] = {} 142 self.compute_device: Optional[torch.device] = None 143 self._gradient_predivide_factor: int = 0 144 self._gradient_postdivide_factor: int = 0 145 self._comm_hook: Optional[Callable] = None 146 self._comm_hook_state: Optional[Any] = None 147 self._unshard_event: Optional[torch.Event] = None 148 # Abstract device handle for fsdp compute device. For now, 149 # the compute device must implement cuda semantics used by fsdp 150 self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle() 151 # All following attributes should only be used for root states: 152 # Save these static lists to avoid the repeated tree traversals 153 self._all_fsdp_states: List[_FSDPState] = [] 154 self._all_handles: List[flat_param_file.FlatParamHandle] = [] 155 self._fsdp_extension: Optional[FSDPExtensions] = None 156 157 158def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]: 159 state = _get_module_state(module) 160 if state is None or not isinstance(state, _FSDPState): 161 return None 162 return state 163 164 165def _get_module_fsdp_state_if_fully_sharded_module( 166 module: nn.Module, 167) -> Optional[_FSDPState]: 168 state = _get_module_fsdp_state(module) 169 if state is None: 170 return None 171 if state == module: # FullyShardedDataParallel module case. 172 return state 173 if module in state._fully_sharded_module_to_handle: # fully_shard case. 174 return state 175 return None 176 177 178class TrainingState(Enum): 179 """ 180 An enum that indicates the state of a ``FullyShardedDataParallel` instance. 181 """ 182 183 IDLE = auto() 184 FORWARD_BACKWARD = auto() 185 SUMMON_FULL_PARAMS = auto() 186 187 188class HandleTrainingState(Enum): 189 """ 190 An enum that indicates the state of a ``FlatParamHandle`. 191 """ 192 193 IDLE = auto() 194 FORWARD = auto() 195 BACKWARD_PRE = auto() 196 BACKWARD_POST = auto() 197 SUMMON_FULL_PARAMS = auto() 198 199 200def _is_composable(state: _FSDPState): 201 # TODO: This is a temporary hack for differentiate between code paths. 202 return not isinstance(state, nn.Module) 203 204 205@no_type_check 206def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]: 207 """ 208 Returns the ``FlatParamHandle`` s corresponding to ``module``. This is 209 the handle that contains some parameter in ``module``. 210 """ 211 if _is_composable(state): 212 # A valid FSDP state may have no managed parameters and hence no 213 # handles, meaning no entry in `_fully_sharded_module_to_handles` 214 if state._handle is None: 215 return None 216 assert ( 217 module in state._fully_sharded_module_to_handle 218 ), f"Expects a fully sharded module but got {module} on rank {state.rank}" 219 return state._fully_sharded_module_to_handle[module] 220 else: 221 # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance. 222 return module._handle 223 224 225@no_type_check 226def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool: 227 """Returns if ``module`` has parameters managed by FSDP.""" 228 return _module_handle(state, module) is not None 229 230 231def _get_sharding_strategy(handle): 232 """ 233 Returns the sharding strategy of the handle. 234 """ 235 return handle._sharding_strategy if handle else None 236 237 238def clean_tensor_name(tensor_name: str) -> str: 239 """ 240 Cleans the parameter or buffer name by removing any module wrapper 241 prefixes. 242 """ 243 tensor_name = tensor_name.replace(FSDP_PREFIX, "") 244 # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as 245 # it couples `CheckpointWrapper` and FSDP and also does not scale for more 246 # module wrappers. 247 tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "") 248 return tensor_name 249 250 251def _set_fsdp_flattened(tensor: torch.Tensor) -> None: 252 """ 253 Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to 254 avoid re-flattening it during nested construction. 255 """ 256 setattr(tensor, FSDP_FLATTENED, True) 257 258 259def _is_fsdp_flattened(tensor: torch.Tensor) -> bool: 260 """Returns if ``tensor`` has been marked as flattened by FSDP.""" 261 return getattr(tensor, FSDP_FLATTENED, False) 262 263 264def _named_parameters_with_duplicates( 265 module: nn.Module, **kwargs: Any 266) -> List[Tuple[str, nn.Parameter]]: 267 """ 268 This API is required as some modules overwrite `named_parameters()` but do not support 269 `remove_duplicate`. 270 """ 271 assert ( 272 "remove_duplicate" not in kwargs 273 ), "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." 274 kwargs["remove_duplicate"] = False 275 try: 276 ret = list(module.named_parameters(**kwargs)) 277 except AssertionError as e: 278 kwargs.pop("remove_duplicate") 279 ret = list(module.named_parameters(**kwargs)) 280 return ret 281 282 283def _get_param_to_fqns( 284 model: torch.nn.Module, 285 dedup_shared_params: bool = True, 286) -> Dict[nn.Parameter, List[str]]: 287 """ 288 Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here, 289 we use canonical to mean the fully-qualified name assigned to the parameter 290 based on its position in the original nn.Module hierarchy before any wrapper 291 or parallelism has been applied to it. This is in contrast to FQNs that may be 292 generated after parallelisms or wrappers have been applied to the model. 293 294 Each normal parameter maps to a singleton list containing its FQN, while each 295 ``FlatParameter`` maps to a list of its original parameter FQNs, which may 296 have length greater than one. All FQNs are prefixed starting from ``model``. 297 298 In the case where FSDP was applied with ``use_orig_params=True``, there should be no 299 ``FlatParameter`` s registered to the model's modules and this mapping will only 300 contain mappings from ``nn.Parameter`` s to singleton FQN lists. 301 302 It is only in the case where FSDP was applied with ``use_orig_params=False`` where 303 a ``FlatParameter`` will be registered in place of the original parameters and there 304 will be mappings from each ``FlatParameter`` to lists of FQNs corresponding to the 305 original parameters. 306 307 Args: 308 model (torch.nn.Module): Root module (which may or may not be a 309 :class:`FullyShardedDataParallel` instance). 310 dedup_shared_params (bool): For shared parameters, if ``True``, only 311 includes the FQNs corresponding to the first encounter of the 312 shared parameter in the module traversal; if ``False``, then 313 includes the FQNs across all encounters. (Default: ``True``) 314 """ 315 316 def module_fn(module, prefix, tree_level, param_to_fqns): 317 for param_name, param in _named_parameters_with_duplicates( 318 module, recurse=False 319 ): 320 local_fqns = ( 321 param._fqns 322 if isinstance(param, flat_param_file.FlatParameter) 323 else [param_name] 324 ) # prefixed from `module` 325 global_fqns = [ 326 clean_tensor_name(prefix + name) for name in local_fqns 327 ] # prefixed from the top level `model` (i.e. including `prefix`) 328 is_shared_param = param in param_to_fqns 329 if not is_shared_param: 330 param_to_fqns[param] = global_fqns 331 else: 332 if isinstance(param, flat_param_file.FlatParameter): 333 # DMP overwrites `named_parameters` and skip (advance to 334 # the next child module) the wrapped_module (e.g., 335 # _dmp_wrapped_module and _fsdp_wrapped_module). When a user 336 # calls `named_child` to traverse the module recursively and 337 # calls `named_parameters` with `recurse=False`, parameters 338 # will be traversed more than once. 339 # This hack is specified designed for DMP + FSDP. We 340 # overwrite the flat_parameters traversal result to only obtain 341 # the last one, which happens to be the correct one. 342 # 343 # TODO: Remove this hack once DMP + FSDP is not supported. 344 warnings.warn( 345 "FlatParameter is being traversed more than once. " 346 "This case should only happen when using " 347 "DistributedModelParallel with FullyShardedDataParallel." 348 ) 349 param_to_fqns[param] = global_fqns 350 elif not dedup_shared_params: 351 param_to_fqns[param].extend(global_fqns) 352 353 def return_fn(param_to_fqns): 354 return param_to_fqns 355 356 param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {} 357 return _apply_to_modules( 358 model, 359 module_fn, 360 return_fn, 361 [key for key, _ in _named_parameters_with_duplicates(model)], 362 param_to_unflat_param_names, 363 ) 364 365 366@no_type_check 367def _log_post_backward_hook( 368 state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger 369) -> None: 370 # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for. 371 # Below logging of module names this post-bwd hook fires for can help debug certain 372 # cases where hooks don't fire, such as under certain activation checkpoint configs. 373 if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO: 374 param_fqns = _get_handle_fqns_from_root(state, handle) 375 logger.warning("FSDP firing post-backward hooks for parameters %s", param_fqns) 376 377 378@no_type_check 379def _get_handle_fqns_from_root( 380 state: _FSDPState, handle: "FlatParamHandle" 381) -> Optional[List[str]]: 382 if handle is None: 383 return None 384 param_to_fqn = state._exec_order_data.param_to_fqn 385 handle_params = handle.flat_param._params # only populated for use_orig_params 386 param_fqns = [ 387 fqn for fqn_list in [param_to_fqn[p] for p in handle_params] for fqn in fqn_list 388 ] 389 return param_fqns 390 391 392def _apply_to_modules( 393 root_module: torch.nn.Module, 394 module_fn: Callable, 395 return_fn: Callable, 396 filter_fqns: Optional[List[str]] = None, 397 *args, 398 **kwargs, 399): 400 """ 401 Performs a pre-order traversal of the modules in the hierarchy rooted at 402 ``root_module``, applying ``module_fn`` at each module and finally 403 returning a value using ``return_fn``. The traversal constructs the full 404 module prefix name (e.g. "module.submodule." just like in model state dict) 405 and makes that available to ``module_fn``. 406 407 ``filter_fqns`` is used because some module may have its own prefix similar 408 to ``FullyShardedDataParallel`` and the ``named_parameters()`` is overwritten 409 to remove the prefix. 410 """ 411 412 def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs): 413 # Call the module function before recursing over children (pre-order) 414 module_fn(module, prefix, tree_level, *args, **kwargs) 415 for submodule_name, submodule in module.named_children(): 416 if submodule is None: 417 continue 418 new_prefix = prefix + submodule_name + "." 419 new_tree_level = tree_level + 1 420 if filter_fqns is not None: 421 for fqn in filter_fqns: 422 if fqn.startswith(new_prefix): 423 break 424 else: 425 # DMP's named_parameter() will mess up the traversal with 426 # ``named_children`` + `named_parameter(recurse=False)``. 427 # This hack is a must to make the traversal work. 428 # TODO: Remove this hack once DMP + FSDP is not supported. 429 # It turns out that recursive wrapping may trigger this as 430 # well. 431 if ( 432 submodule_name == "_fsdp_wrapped_module" 433 or submodule_name == "_dmp_wrapped_module" 434 ): 435 new_prefix = prefix 436 elif submodule_name == "module": 437 new_prefix = prefix 438 f(submodule, new_prefix, new_tree_level, *args, **kwargs) 439 440 f(root_module, "", 0, *args, **kwargs) 441 return return_fn(*args, **kwargs) 442 443 444@no_type_check 445def _assert_in_training_states( 446 state: _FSDPState, 447 training_states: List[TrainingState], 448) -> None: 449 """Asserts that FSDP is in the states ``_training_states``.""" 450 # Raise a `ValueError` instead of using `assert` to ensure that these 451 # logical assertions run even if `assert`s are disabled 452 if state.training_state not in training_states: 453 msg = ( 454 f"expected to be in states {training_states} but current state is " 455 f"{state.training_state}" 456 ) 457 # Print the error on rank 0 in case this is called in the backward pass 458 if state.rank == 0: 459 if isinstance(state, nn.Module): 460 print(f"Asserting FSDP instance is: {state}") 461 print(f"ERROR: {msg}") 462 traceback.print_stack() 463 raise ValueError(msg) 464 465 466def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]: 467 """ 468 Returns: 469 Set[nn.Module]: The subset of ``modules`` that are root modules (i.e. 470 parent-less) with respect to the modules in the set itself. In other 471 words, these are the modules in ``modules`` that are not the child of 472 any other module in ``modules``. 473 """ 474 root_modules: Set[nn.Module] = set() 475 module_to_submodules = {module: set(module.modules()) for module in modules} 476 for candidate_module in modules: 477 is_root_module = True 478 for module, submodules in module_to_submodules.items(): 479 is_child_module = ( 480 candidate_module is not module and candidate_module in submodules 481 ) 482 if is_child_module: 483 is_root_module = False 484 break 485 if is_root_module: 486 root_modules.add(candidate_module) 487 return root_modules 488 489 490def _override_module_mixed_precision( 491 root: torch.nn.Module, 492 module_classes_to_override: Iterable[Type[nn.Module]], 493 wrap_override_dict: Dict[str, Any] = {"mixed_precision": None}, # noqa: B006 494) -> Set[Type[nn.Module]]: 495 module_classes_to_override = tuple(set(module_classes_to_override)) 496 # Return a set of the actually overridden module classes 497 overridden_module_classes: Set[Type[nn.Module]] = set() 498 for mod in root.modules(): 499 if isinstance(mod, module_classes_to_override): 500 overridden_module_classes.add(type(mod)) 501 mod._wrap_overrides = wrap_override_dict # type: ignore[assignment] 502 # TODO: We need to run this mixed precision ignored module in fp32, 503 # but ensure subsequent modules, that may possibly be running with 504 # mixed precision, still receive the appropriate precision inputs 505 # without user having to adjust mixed precision config too much. 506 # As a result, we attach pre and post forward hooks to up / down 507 # cast. We should revisit this design. 508 509 def cast_fn( 510 dtype: torch.dtype, module: nn.Module, x: torch.Tensor 511 ) -> torch.Tensor: 512 if not torch.is_floating_point(x) or x.dtype == dtype: 513 return x 514 _MODULE_TO_INP_DTYPE[module] = x.dtype 515 return x.to(dtype) 516 517 def forward_pre_hook(module, args): 518 return _apply_to_tensors(partial(cast_fn, torch.float32, module), args) 519 520 def forward_post_hook(module, args, output): 521 # NOTE: If the forward did not have any floating-point tensors, 522 # then the dtype will not be set for this module, and we do not 523 # upcast the dtype. 524 if module in _MODULE_TO_INP_DTYPE: 525 old_dtype = _MODULE_TO_INP_DTYPE[module] 526 return _apply_to_tensors( 527 partial(cast_fn, old_dtype, module), output 528 ) 529 530 # We intentionally append both of these hooks so that they run after 531 # all other hooks. 532 mod.register_forward_pre_hook(forward_pre_hook, prepend=False) 533 mod.register_forward_hook(forward_post_hook, prepend=False) 534 return overridden_module_classes 535 536 537def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None: 538 # FIXME record_stream doesn't work with non-cuda/mtia tensors 539 if tensor.device.type not in [ 540 "cuda", 541 "mtia", 542 torch._C._get_privateuse1_backend_name(), 543 ]: 544 return 545 546 if torch.distributed._functional_collectives.is_torchdynamo_compiling(): 547 return 548 # from @ezyang: 549 # The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin 550 # Looking over the PR, it looks like this is because we don't actually support Stream arguments 551 # in torch dispatch, so it just chokes. 552 # If Dynamo is able to answer "are there any torch dispatch modes" active (it should answer False), 553 # a better version of this would just be to check if there are any modes before disabling dispatch. 554 # TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here. 555 tensor.record_stream(stream) 556 else: 557 with no_dispatch(): 558 tensor.record_stream(stream) 559