# mypy: allow-untyped-defs
import contextlib
import functools
import logging
import os
import warnings
from enum import auto, Enum
from itertools import accumulate, chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterator,
    List,
    NamedTuple,
    no_type_check,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
    _named_parameters_with_duplicates,
    _no_dispatch_record_stream,
    _set_fsdp_flattened,
    HandleTrainingState,
)
from torch.distributed.utils import (
    _alloc_storage,
    _data_ptr_allocated,
    _free_storage,
    _p_assert,
)
from torch.nn.parameter import _ParameterMeta  # type: ignore[attr-defined]
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

from ._fsdp_extensions import (
    _ext_post_unflatten_transform,
    _ext_pre_flatten_transform,
    FSDPExtensions,
)


__all__ = [
    "FlatParameter",
    "FlatParamHandle",
    "FlatParamShardMetadata",
    "ParamInfo",
    "SharedParamInfo",
    "HandleShardingStrategy",
]

logger = logging.getLogger(__name__)


"""
[Note: Fully Sharded Module]
We define the "fully sharded module" to be the original ``nn.Module`` that owns
a ``FlatParamHandle``. It is the *single* module logically responsible for the
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
forward or backward pass. The fully sharded module should be passed to the
``FlatParamHandle`` constructor.

For the wrapper code path:
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
runs the unshard/reshard on behalf of the fully sharded module by overriding
``nn.Module.forward``.
- The fully sharded module is exactly the module passed to the
``FullyShardedDataParallel`` constructor's ``module`` argument.

For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to ``fully_shard``
or a submodule chosen by the provided wrapping policy.
"""

# Environment variable toggling whether to use unsafe `setattr()` for view
# setting in `_use_sharded_views()` and `_use_unsharded_views()`
# We should use 'safe' by default since it respects method overrides, but for
# special cases, such as when CPU overhead is a concern or when intentionally
# bypassing checks in the overrides, we may use 'unsafe'.
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"

# Environment variable toggling whether to check for parameter/gradient
# writeback in case their storages change after FSDP initialization
# We should check by default since it prevents silent correctness errors, but
# since such changes are atypical, we may want to skip the check to save CPU
# overhead, especially since the check happens in the pre-forward and
# pre-backward each iteration.
_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK"

# Environment variable toggling whether to run in full precision (fp32) or in
# the reduced precision when the model is in ``.eval()`` mode
_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL"

# Value used to fill padding tensors, for debuggability
_FLAT_PARAM_PADDING_VALUE = 42

# Environment variables for disabling the all-gather and reduce-scatter
# communication ops for ablation studies. Note that without these communication
# ops the training won't converge, and you probably need to disable correctness
# checks in your model.
_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER"
_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE"
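
# Illustrative sketch (assumed usage, not part of this module): these toggles
# are read from the process environment when a `FlatParamHandle` is
# constructed, so they must be set before FSDP wrapping, e.g.
#   os.environ["FSDP_SKIP_WRITEBACK_CHECK"] = "1"
#   os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = "1"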


# TODO: Define this for now to avoid circular imports. See if we can remove.
class HandleShardingStrategy(Enum):
    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.FULL_SHARD,
    HandleShardingStrategy.HYBRID_SHARD,
)
NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = (
    HandleShardingStrategy.SHARD_GRAD_OP,
    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
)


class ParamInfo(NamedTuple):
    """Information for an original parameter."""

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str


class SharedParamInfo(NamedTuple):
    """
    Additional information for a shared parameter.

    For each shared parameter, we designate one module and its parameter
    variable to be the primary owner, determined as the first one encountered
    in the parameter walk. These are prefixed with "prim". The primary module
    and parameter do not have their own :class:`SharedParamInfo` instance.
    """

    param_name: str  # unprefixed
    module: nn.Module
    module_name: str
    prim_param_name: str  # unprefixed
    prim_module: nn.Module
    prim_module_name: str


class _ShardParamInfo(NamedTuple):
    """Shard-related information for an original parameter."""

    in_shard: bool
    # Used to index into the sharded flat parameter, e.g.
    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
    offset_in_shard: Optional[int]
    numel_in_shard: Optional[int]
    # Used to get the part of the parameter in the local shard from a flattened
    # version of the unsharded parameter, e.g.
    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
    intra_param_start_idx: Optional[int]
    intra_param_end_idx: Optional[int]  # inclusive


class FlatParamShardMetadata(NamedTuple):
    """
    This holds metadata specific to this rank's shard of the flat parameter.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
            of the parameters; see :class:`FlatParameter`.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels) giving this rank's part of each flattened
            original parameter.
    """

    param_names: Tuple[str, ...]
    param_shapes: Tuple[torch.Size, ...]
    param_numels: Tuple[int, ...]
    param_offsets: Tuple[Tuple[int, int], ...]
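
# Illustrative sketch (assumed values, not from this module): suppose two
# original parameters with numels 6 and 4 are flattened into a 10-numel flat
# parameter sharded over 2 ranks (5 numel each). Rank 1 owns unsharded indices
# [5, 9], i.e. the tail of the first parameter and all of the second, so its
# metadata would look roughly like
#   FlatParamShardMetadata(
#       param_names=("layer0.weight", "layer1.bias"),  # hypothetical FQNs
#       param_shapes=(torch.Size([2, 3]), torch.Size([4])),
#       param_numels=(6, 4),
#       param_offsets=((5, 5), (0, 3)),  # inclusive [start, end] per parameter
#   )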


class _FlatParameterMeta(_ParameterMeta):
    # Make `isinstance(t, FlatParameter)` return True for custom tensor
    # instances that have the _is_flat_param flag for BC
    def __instancecheck__(self, instance):
        # NB: do NOT test the super implementation
        return isinstance(instance, torch.Tensor) and getattr(
            instance, "_is_flat_param", False
        )


class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
    """
    This is the flat parameter used by :class:`FullyShardedDataParallel`.

    It is comprised of one or more original parameters, which are flattened and
    concatenated to construct the flat parameter.

    Under the current design, this parameter logically represents both the
    unsharded and sharded flat parameter, and its data changes storages
    dynamically.
        - In the :class:`FullyShardedDataParallel` constructor, the parameter
        is initialized as unsharded and then sharded in-place.
        - At runtime, the parameter is lazily (re)-initialized. The sharded
        parameter data is saved in ``self._local_shard``, and a new ``Tensor``
        ``self._full_param_padded`` is created, which is the all-gather
        destination and owns the unsharded parameter storage thereafter. (See
        :meth:`FlatParamHandle.init_flat_param_attributes`.)
        - Throughout runtime, the parameter data changes storages as needed,
        e.g. to the sharded flat parameter, low precision sharded flat
        parameter, or the unsharded flat parameter.

    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
    padding, we have two versions of the per-parameter numels, one that
    includes the padding (``_numels_with_padding``) and one that does not
    (``_numels``). The former may be longer than the other per-parameter data
    structures, while the latter has the same length as the number of actual
    original parameters, like the other per-parameter data structures.

    NOTE: This is not a real class; instead, you will always get a Parameter
    back out if you try to create one of these. This is similar to the trick
    we implemented for Parameter to get it to work with subclasses; this
    is primarily so that FlatParameter supports combination with FakeTensor.

    Attributes:
        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
            without right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding.
        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
            with right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding. This
            is only set for sharded strategies since they require padding for
            the all-gather.
        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
            This is also set for ``NO_SHARD``, in which case it is the same as
            the unsharded size. (We omit "padded" because there is no
            analogous unpadded one.)

        _num_params (int): Number of original parameters flattened into this
            flat parameter. This is the length of the per-parameter data
            structures.
        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
            entry; see :class:`ParamInfo` for details.
        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
        _fqns (Tuple[str, ...]): Each parameter's fully qualified name (FQN)
            prefixed from the ``_fully_sharded_module``.
            The names are guaranteed to be unique in the subtree rooted at that
            module.
        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
            extension (i.e. some per-parameter state) used to customize
            pre-flatten and post-unflatten behavior or ``None``. This is
            experimental, and users should not depend on its existence in the
            future.
        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
            including entries for the padding. This is used to construct views
            into the flat parameter via ``torch.split()``. This may have length
            longer than ``_num_params``.
        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
            padding. This has length equal to ``_num_params``.
        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
            shard parameter info; see :class:`_ShardParamInfo` for details.
        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
            info entries; see :class:`SharedParamInfo` for details.
        _modules (Set[nn.Module]): Modules that contain some original parameter
            that is flattened into the flat parameter.

        _shard_numel_padded (int): Numel padded for this rank's sharded flat
            parameter.
        _local_shard (Tensor): Sharded flat parameter with padding if using a
            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
            unsharded flat parameter, and there is no notion of a sharded flat
            parameter or padded unsharded flat parameter.
        _full_param_padded (Tensor): Unsharded flat parameter with padding.
            This is not defined for ``NO_SHARD``. When using mixed precision
            for parameters, this has the low precision.
        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
            parameter with padding. This is used for unsharding outside of
            computation when using mixed precision for parameters. This is
            never defined for ``NO_SHARD``.
        _post_backward_hook_handle (RemovableHandle):
            Flat parameter's post-backward hook handle. (Compile only)
        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
            Flat parameter's :class:`AccumulateGrad` object and post-backward
            hook handle. (Eager only)
        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
            This is only defined when parameter mixed precision is enabled. For
            ``NO_SHARD``, this is used for computation.
        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
            This is only defined when offloading parameters is enabled.
        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
            iterations for gradient accumulation without :meth:`no_sync`.

        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
            then each original parameter variable; otherwise, ``None``. This
            does not include any padding tensors.
        _shared_params (Optional[List[nn.Parameter]]): The original shared
            parameter variables if ``use_orig_params=True`` and ``None``
            otherwise.
        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
            views created in the forward and tracked by autograd when
            ``use_orig_params=True`` and is ``None`` otherwise. This is to
            preserve those ``Tensor`` variables for the backward to ensure that
            the ``FlatParameter``'s ``AccumulateGrad`` object does not change,
            since if it changes, then the post-backward hook does not run. This
            is relevant for cases like reentrant activation checkpointing.
        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
            a mask over the original parameters' gradients indicating whether
            each one is logically ``None`` or not; otherwise, ``None``. This
            does not include entries for padding. This mask is needed because
            only some of the parameters may have ``None`` gradient, in which
            case the flat gradient must be non-``None`` and must use zeros to
            approximate those original ``None`` gradients. This mask informs
            FSDP to set the original parameter gradients to ``None`` (instead
            of zeros) as needed.
    """

    _unpadded_unsharded_size: torch.Size
    _padded_unsharded_size: torch.Size
    _sharded_size: torch.Size
    _num_params: int
    _param_infos: Tuple[ParamInfo, ...]
    _shapes: Tuple[torch.Size, ...]
    _fqns: Tuple[str, ...]
    _param_extensions: Tuple[Optional[Any], ...]
    _numels_with_padding: Tuple[int, ...]
    _numels: Tuple[int, ...]
    _shard_param_infos: Tuple[_ShardParamInfo, ...]
    _shared_param_infos: Tuple[SharedParamInfo, ...]
    _modules: Set[nn.Module]
    _shard_numel_padded: int
    _local_shard: Tensor
    _full_param_padded: Tensor
    _full_prec_full_param_padded: Tensor
    # Eager only
    _post_backward_hook_state: Tuple[Any, Any]
    # Compile only
    _post_backward_hook_handle: Any
    _mp_shard: Tensor
    _cpu_grad: Tensor
    _saved_grad_shard: Tensor
    _params: Optional[List[nn.Parameter]]
    _shared_params: Optional[List[nn.Parameter]]
    _tensors: Optional[List[Optional[Tensor]]]
    _is_grad_none_mask: Optional[List[bool]]

    _is_padding_mask: List[bool]

    def __new__(cls, data=None, requires_grad=True):
        assert cls is FlatParameter, "subclasses of FlatParameter are not supported"
        r = nn.Parameter.__new__(nn.Parameter, data, requires_grad)  # type: ignore[call-arg]
        r._is_flat_param = True  # type: ignore[attr-defined]
        return r

    # NB: This is not a regular method, because FlatParameters are not actually
    # instances of this class (see __new__ above). So this must be called
    # explicitly through the classmethod, e.g.
    # `FlatParameter._init_metadata(flat_param, ...)`.
    @classmethod
    def _init_metadata(
        cls,
        self,
        param_infos: List[ParamInfo],
        numels: List[int],
        shapes: List[torch.Size],
        fqns: List[str],
        shared_param_infos: List[SharedParamInfo],
        param_extensions: List[Optional[Any]],
        params: Optional[List[nn.Parameter]],
        shared_params: Optional[List[nn.Parameter]],
        is_padding_mask: List[bool],
    ) -> None:
        """
        Initialize attributes holding metadata about the original parameters comprising the flat parameter.

        We expose this method separately from the constructor to keep the
        constructor only responsible for the flat parameter's tensor data. This
        method should only be called once per model, while the constructor may
        be called multiple times, e.g. when reloading from a checkpoint, in
        which case only the tensor data needs to be passed to the constructor.
        Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
        metadata is correctly assumed to be unchanged.

        Args:
            See the Attributes in the class docstring.
        """
        assert len(param_infos) == len(shapes)
        assert len(param_infos) == len(fqns)
        assert len(param_infos) == len(param_extensions)
        self._num_params = len(param_infos)
        self._param_infos = param_infos
        self._shapes = shapes
        self._fqns = fqns
        self._param_extensions = param_extensions
        self._is_padding_mask = is_padding_mask

        numels_without_padding: List[int] = []
        for numel, is_padding in zip(numels, is_padding_mask):
            if not is_padding:
                numels_without_padding.append(numel)
        self._numels = tuple(numels_without_padding)
        self._numels_with_padding = tuple(numels)
        assert len(self._numels) == self._num_params

        self._shared_param_infos = tuple(shared_param_infos)
        self._modules = {pi.module for pi in self._param_infos}.union(
            {spi.module for spi in self._shared_param_infos}
        )
        assert (params is None) == (shared_params is None)
        if params is not None:
            assert shared_params is not None and len(shared_params) == len(
                shared_param_infos
            )
            self._params = []
            for param, is_padding in zip(params, is_padding_mask):
                if not is_padding:
                    self._params.append(param)
            self._shared_params = shared_params
            # Mark the original parameters to avoid flattening them into
            # another `FlatParameter` during recursive construction
            for param in chain(self._params, self._shared_params):
                _set_fsdp_flattened(param)
            self._is_grad_none_mask = [False for _ in range(self._num_params)]
            self._tensors = [None for _ in range(self._num_params)]
        else:
            self._params = None
            self._shared_params = None
            self._is_grad_none_mask = None
            self._tensors = None
        self._unpadded_unsharded_size = self.size()
        _set_fsdp_flattened(self)
        # Tracks whether the `FlatParameter`'s post-backward hook has been
        # called to modify the behavior of the post-backward callback
        self._post_backward_called = False


class FlatParamHandle:
    """
    A handle that manages a flat parameter (:class:`FlatParameter`).

    This includes sharding and view management.

    Args:
        params (Sequence[nn.Parameter]): The parameters to flatten into the
            flat parameter.
        fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
        device (torch.device): The compute and communication device, which
            should be a non-CPU device. We refer to it as the compute device.
        sharding_strategy (ShardingStrategy): Sharding strategy to apply to
            this handle's ``FlatParameter``.
        offload_params (bool): Whether to offload the handle's
            ``FlatParameter`` to CPU.
        mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
            setting passed to the FSDP constructor.
        mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
            precision setting passed to the FSDP constructor.
        keep_low_precision_grads (bool): Whether to keep gradients in low
            precision.
        use_orig_params (bool): If ``True``, then FSDP preserves the original
            parameter variables and returns them from ``named_parameters()``
            (e.g. to support different optimizer hyperparameters within one
            :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
            parameters every iteration and returns the :class:`FlatParameter` s
            from ``named_parameters()``.
    """

    ##################
    # INITIALIZATION #
    ##################
    def __init__(
        self,
        params: Sequence[Union[nn.Parameter, Tensor]],
        fully_sharded_module: nn.Module,
        device: torch.device,
        sharding_strategy: HandleShardingStrategy,
        offload_params: bool,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
        keep_low_precision_grads: bool,
        process_group: dist.ProcessGroup,
        use_orig_params: bool,
        *,
        fsdp_extension: Optional[FSDPExtensions] = None,
    ):
        super().__init__()
        params = list(params)
        if len(params) == 0:
            raise ValueError(
                f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
            )
        self._init_setattr_fns()
        self._skip_writeback_check = (
            os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1"
        )
        self._use_full_prec_in_eval = (
            os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1"
        )
        self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1"
        self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1"
        if self._skip_writeback_check:
            _warn_skip_writeback_check(
                logger,
                f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check "
                "for parameter or gradient writeback. Changing parameter or "
                "gradient storages may lead to silent correctness errors.",
            )
        if self._use_fake_all_gather:
            _warn_use_fake_all_gather(
                logger,
                f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute "
                "all-gather ops. Your training will be incorrect, but "
                "this can reveal how much time is spent on all-gather ops.",
            )
        if self._use_fake_reduce:
            _warn_use_fake_reduce(
                logger,
                f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute "
                "reduce-scatter ops. Your training will be incorrect, but "
                "this can reveal how much time is spent on reduce-scatter ops.",
            )
        # Only align addresses for `use_orig_params=True` (for now)
        align_addresses = use_orig_params
        self._init_get_unflat_views_fn(align_addresses)
        self.device = device
        self._device_handle = _FSDPDeviceHandle.from_device(self.device)
        self.process_group = process_group
        if self._use_fake_all_gather or self._use_fake_reduce:
            self._fake_process_group = FakeProcessGroup(
                rank=process_group.rank(), world_size=process_group.size()
            )
        self.rank = process_group.rank()
        self.world_size = process_group.size()
        self._sharding_strategy = sharding_strategy
        self._offload_params = offload_params
        self._use_orig_params = use_orig_params
        self._keep_low_precision_grads = keep_low_precision_grads
        self._training_state = HandleTrainingState.IDLE
        self._debug_level = dist.get_debug_level()
        self._fully_sharded_module = fully_sharded_module
        # For strategies that do not free after forward, we skip using sharded
        # views after forward since the unsharded data exists. We still switch
        # `self.flat_param` to point to the sharded flat parameter since what
        # it points to parameterizes behavior. We use the following attribute
        # to track which tensor data the parameters are unsharded views into.
        self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None
        # The index in the state's `all_handles`, which must be the
        # same across ranks for the execution order validation to work
        self._handle_index: Optional[int] = None
        # Index in `handles_to_pre_forward_order`
        self._pre_forward_order_index: Optional[int] = None
        # Index in `handles_post_forward_order`
        self._post_forward_index: Optional[int] = None
        # Used for guarding against mistargeted forward prefetches
        self._needs_pre_forward_unshard = False
        # Used for guarding against mistargeted backward prefetches
        self._needs_pre_backward_unshard = False
        # Was the handle prefetched? Set on successful `_prefetch_handle` and unshard
        self._prefetched = False
        # Optimistically assume a valid input `params` and set dtype attributes
        # before `_init_flat_param_and_metadata()`, which performs the actual
        # validation
        self._orig_param_dtype = params[0].dtype
        self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
        assert self._fwd_bwd_param_dtype is not None  # mypy
        self._aligned_numel = (
            _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
            if align_addresses
            else 0
        )
        self._fsdp_extension = fsdp_extension
        self._init_flat_param_and_metadata(
            params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
        )
        self._use_unsharded_views(as_params=False)

    def _init_setattr_fns(self):
        use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1"
        self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
        self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
        if use_unsafe_setattr:
            self._setattr_tensor = _unsafe_setattr_tensor
            self._setattr_param = _unsafe_setattr_param
        else:
            self._setattr_tensor = _safe_setattr_tensor_or_param
            self._setattr_param = _safe_setattr_tensor_or_param

    def _init_get_unflat_views_fn(self, align_addresses: bool):
        self._get_unflat_views = (
            self._get_unflat_views_aligned
            if align_addresses
            else self._get_unflat_views_unaligned
        )

    def _init_flat_param_and_metadata(
        self,
        params: List[Union[Tensor, nn.Parameter]],
        module: nn.Module,
        aligned_numel: int,
        use_orig_params: bool,
    ) -> None:
        """
        Initialize the ``FlatParameter`` and its metadata.

        NOTE: This should only be called once at construction time, after which
        the ``FlatParameter`` metadata is assumed to be static.

        NOTE: The elements of ``params`` should only be ``Tensor`` s when
        composing with ``DTensor`` -based tensor parallelism, in which case the
        elements may be ``DTensor`` local shards.
        """
        if len(params) == 0:
            raise ValueError("Expects non-empty `params`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        (
            dtype,
            flat_param_requires_grad,
            device,
        ) = self._validate_tensors_to_flatten(params)
        params_set = set(params)
        # For alignment padding, only `params_to_flatten`, `is_padding_mask`,
        # and `numels` get entries for the padding; the other per-parameter
        # lists get no entries for it.
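        # Illustrative sketch (assumed values): with `aligned_numel=4` and two
        # parameters of numel 3 and 6 encountered in order, the loop below
        # inserts a 1-numel padding tensor so that the second parameter starts
        # at an aligned offset:
        #   numels          = [3, 1, 6]  (middle entry is padding)
        #   is_padding_mask = [False, True, False]
        #   total_numel = 10, total_numel_without_padding = 9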
        param_infos: List[ParamInfo] = []
        numels: List[int] = []
        shapes: List[torch.Size] = []
        fqns: List[str] = []
        shared_param_infos: List[SharedParamInfo] = []
        shared_param_memo: Dict[
            Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
        ] = {}
        params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
        shared_params: List[Union[Tensor, nn.Parameter]] = []
        param_extensions: List[Any] = []
        is_padding_mask: List[bool] = []
        total_numel = total_numel_without_padding = 0
        for submodule_name, submodule in module.named_modules(remove_duplicate=False):
            for param_name, param in _named_parameters_with_duplicates(
                submodule, recurse=False
            ):
                if param not in params_set:
                    continue
                if param in shared_param_memo:  # shared reference
                    prim_module, prim_module_name, prim_param_name = shared_param_memo[
                        param
                    ]
                    shared_params.append(param)
                    shared_param_infos.append(
                        SharedParamInfo(
                            param_name,
                            submodule,
                            submodule_name,
                            prim_param_name,
                            prim_module,
                            prim_module_name,
                        )
                    )
                else:
                    if aligned_numel > 0:
                        numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                        if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                            padding_tensor = _construct_padding_tensor(
                                numel_to_pad, dtype, False, device
                            )
                            params_to_flatten.append(padding_tensor)
                            is_padding_mask.append(True)
                            numels.append(numel_to_pad)
                            total_numel += numel_to_pad
                    transform_t, extension = _ext_pre_flatten_transform(
                        param,
                        self._fsdp_extension,
                    )
                    param = cast(nn.Parameter, transform_t)
                    param_extensions.append(extension)
                    shared_param_memo[param] = (submodule, submodule_name, param_name)
                    params_to_flatten.append(param)
                    is_padding_mask.append(False)
                    param_infos.append(ParamInfo(param_name, submodule, submodule_name))
                    numels.append(param.numel())
                    shapes.append(param.shape)
                    fqn = (
                        submodule_name + "." + param_name
                        if submodule_name
                        else param_name
                    )
                    fqns.append(fqn)
                    total_numel += param.numel()
                    total_numel_without_padding += param.numel()
        if len(params_to_flatten) == 0:
            raise ValueError(
                f"`params` were not found in `module`'s tree\n"
                f"params: {params}\nmodule: {module}"
            )
        if (
            self.rank == 0
            and aligned_numel > 0
            and total_numel != total_numel_without_padding
        ):
            logger.debug(
                "FSDP FlatParameter address alignment created "
                "%s numel of padding (%s vs. %s)",
                total_numel - total_numel_without_padding,
                total_numel,
                total_numel_without_padding,
            )
        if aligned_numel > 0:
            # Pad to be divisible by world size to avoid a copy for the
            # post-backward reduce-scatter
            numel_to_pad = self.world_size - (total_numel % self.world_size)
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
                if self.rank == 0:
                    logger.info(
                        "FSDP FlatParameter world size divisibility created "
                        "%s numel of padding",
                        numel_to_pad,
                    )
                padding_tensor = _construct_padding_tensor(
                    numel_to_pad, dtype, False, device
                )
                params_to_flatten.append(padding_tensor)
                is_padding_mask.append(True)
                numels.append(numel_to_pad)
                total_numel += numel_to_pad
        # Pass `aligned_numel=0` since we already included padding tensors
        self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
            params_to_flatten,
            aligned_numel=0,
            requires_grad=flat_param_requires_grad,
        )
        FlatParameter._init_metadata(
            self.flat_param,
            param_infos,
            numels,
            shapes,
            fqns,
            shared_param_infos,
            param_extensions,
            _convert_to_params(params_to_flatten) if use_orig_params else None,
            _convert_to_params(shared_params) if use_orig_params else None,
            is_padding_mask,
        )

    def _validate_tensors_to_flatten(
        self, tensors: List[Union[Tensor, nn.Parameter]]
    ) -> Tuple:
        """Validate the tensors to flatten and return any necessary metadata."""
        dtype: Optional[torch.dtype] = None
        # Return as the logical OR over each tensor's value
        flat_param_requires_grad: Optional[bool] = None
        device: Optional[torch.device] = None
        # For `use_orig_params=True`, permit non-uniform `requires_grad`
        for tensor in tensors:
            if isinstance(tensor, FlatParameter):
                raise ValueError("Cannot flatten a `FlatParameter`")
            if dtype is None and not tensor.is_floating_point():
                raise ValueError("Cannot flatten integer dtype tensors")
            if dtype is not None and tensor.dtype != dtype:
                raise ValueError(
                    f"Must flatten tensors with uniform dtype but got {dtype} "
                    f"and {tensor.dtype}"
                )
            if (
                not self._use_orig_params
                and flat_param_requires_grad is not None
                and tensor.requires_grad != flat_param_requires_grad
            ):
                raise ValueError(
                    "Must flatten tensors with uniform `requires_grad` when "
                    "`use_orig_params=False`"
                )
            if device is not None and tensor.device != device:
                raise ValueError(
                    "Must flatten tensors on the same device but got both "
                    f"{device} and {tensor.device}"
                )
            dtype = tensor.dtype
            flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
            device = tensor.device
        assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
        return dtype, flat_param_requires_grad, device

    def flatten_tensors(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
    ) -> Tensor:
        """
        Flatten ``tensors`` into a single flat tensor.

        The flattening optionally includes padding if ``aligned_numel`` is
        greater than 0, where ``aligned_numel`` gives the numel required to
        have address alignment.

        NOTE: The padding alignment algorithm must be kept in sync with
        :meth:`_init_flat_param_and_metadata`. We separate the two methods
        because the initialization happens once, whereas this method may be
        called multiple times throughout training (e.g. for checkpointing).
        """
        if len(tensors) == 0:
            raise ValueError("Expects non-empty `tensors`")
        if aligned_numel < 0:
            raise ValueError(
                f"Expects non-negative `aligned_numel` but got {aligned_numel}"
            )
        dtype, _, device = self._validate_tensors_to_flatten(tensors)
        flat_tensors: List[Tensor] = []
        if aligned_numel > 0:
            total_numel = 0
            for tensor in tensors:
                numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                    padding_tensor = _construct_padding_tensor(
                        numel_to_pad, dtype, False, device
                    )
                    flat_tensors.append(padding_tensor)
                    total_numel += numel_to_pad
                flat_tensors.append(torch.flatten(_detach_if_needed(tensor)))
                total_numel += tensor.numel()
            numel_to_pad = self.world_size - (total_numel % self.world_size)
            if numel_to_pad > 0 and numel_to_pad < self.world_size:
                padding_tensor = _construct_padding_tensor(
                    numel_to_pad, dtype, False, device
                )
                flat_tensors.append(padding_tensor)
                total_numel += numel_to_pad
        else:
            flat_tensors = [
                torch.flatten(_detach_if_needed(tensor)) for tensor in tensors
            ]
        return torch.cat(flat_tensors, dim=0)

    def flatten_tensors_into_flat_param(
        self,
        tensors: List[Tensor],
        aligned_numel: int,
        requires_grad: bool,
    ) -> FlatParameter:
        flat_param_data = self.flatten_tensors(tensors, aligned_numel)
        return FlatParameter(flat_param_data, requires_grad=requires_grad)

    def _init_param_reduce_dtypes(
        self,
        mp_param_dtype: Optional[torch.dtype],
        mp_reduce_dtype: Optional[torch.dtype],
    ) -> None:
        """
        Initialize param and reduce dtypes.

        Precondition: ``self.flat_param`` is set. This ensures that this
        handle's parameters have a single dtype.

        Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
        ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
        is ``None``, then we assume the original parameter dtype. One special
        case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
        is ``None``, in which case we assume the gradient reduction dtype
        matches the forward/backward parameter dtype.
        """
        # Save whether these dtypes were specified so that we permit the
        # parameter dtype to change up until the lazy initialization
        self._low_prec_param_dtype_specified = mp_param_dtype is not None
        self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
        if (
            self._low_prec_param_dtype_specified
            and not self._low_prec_reduce_dtype_specified
        ):
            # Special case: infer gradient reduction mixed precision
            self._fwd_bwd_param_dtype = mp_param_dtype
            self._reduce_dtype = self._fwd_bwd_param_dtype
        else:
            self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
            self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
        assert self._fwd_bwd_param_dtype is not None
        assert self._reduce_dtype is not None
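
    # Illustrative sketch (assumed values): with `mp_param_dtype=torch.float16`
    # and `mp_reduce_dtype=None`, the special case above gives
    #   _fwd_bwd_param_dtype == torch.float16 and _reduce_dtype == torch.float16,
    # whereas with both `None` and fp32 original parameters, both stay
    # `torch.float32`.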

    ###################################
    # SHARD INITIALIZATION & METADATA #
    ###################################
    @torch.no_grad()
    def shard(self):
        """
        Shard the handle's ``FlatParameter``.

        This allocates new memory for the sharded flat parameter and frees the
        unsharded flat parameter's storage.

        Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
        metadata attributes are set for all sharding strategies.
        """
        flat_param = self.flat_param
        if not self.uses_sharded_strategy:
            self._init_shard_metadata(0, 0, flat_param.numel() - 1)
        else:
            _p_assert(
                flat_param.storage_offset() == 0,
                "The `FlatParameter` is not the sole occupant of its storage",
            )
            sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
                flat_param, self.rank, self.world_size
            )
            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
                allocated = flat_param._typed_storage()._size() > 0
                if allocated:
                    flat_param._typed_storage()._resize_(0)
            flat_param.set_(sharded_flat_param)  # type: ignore[call-overload]
            start_idx = sharded_flat_param.numel() * self.rank
            end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1  # inclusive
            self._init_shard_metadata(numel_padded, start_idx, end_idx)
        if self._use_orig_params:
            self._use_sharded_views()

    def _init_shard_metadata(
        self,
        numel_padded: int,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> None:
        """
        Initialize shard-related metadata for this rank's shard of the flat parameter.

        This includes ``_sharded_size``, ``_shard_param_infos``, and
        ``_shard_numel_padded``.

        Args:
            numel_padded (int): Numel padded for this rank's sharded flat
                parameter.
            unsharded_start_idx (int): Start index in the unsharded flat
                parameter assigned to this rank.
            unsharded_end_idx (int): End index (inclusive) in the unsharded
                flat parameter assigned to this rank.

        Precondition: ``self.flat_param`` 's data is the sharded flat
        parameter.
        """
        flat_param = self.flat_param
        flat_param._sharded_size = flat_param.size()  # type: ignore[attr-defined]
        sharded_flat_param_numel = flat_param.numel()  # includes `numel_padded`
        _p_assert(
            unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
            f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
        )
        _p_assert(
            numel_padded <= sharded_flat_param_numel,
            f"numel_padded: {numel_padded} "
            f"sharded_flat_param_numel: {sharded_flat_param_numel}",
        )
        shard_param_infos = self._get_shard_metadata(
            unsharded_start_idx, unsharded_end_idx
        )
        assert (
            len(shard_param_infos) == flat_param._num_params
        ), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
        flat_param._shard_param_infos = shard_param_infos  # type: ignore[attr-defined]
        flat_param._shard_numel_padded = numel_padded  # type: ignore[attr-defined]

    def _get_shard_metadata(
        self,
        unsharded_start_idx: int,
        unsharded_end_idx: int,
    ) -> Tuple[_ShardParamInfo, ...]:
        """
        Compute the shard metadata based on ``unsharded_start_idx`` and
        ``unsharded_end_idx`` (inclusive).

        ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of
        the unsharded flat parameter specifying the shard.
        """
        flat_param_offsets = self._get_flat_param_offsets()
        assert len(flat_param_offsets) == len(
            self.flat_param._numels_with_padding
        ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
        shard_param_infos: List[_ShardParamInfo] = []
        sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
        # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
        # into the unsharded flat parameter (inclusive) of the given parameter
        for i, (
            (unsharded_param_start_idx, unsharded_param_end_idx),
            is_padding,
        ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
            if is_padding:
                continue
            in_sharded_flat_param = (
                unsharded_start_idx <= unsharded_param_end_idx
                and unsharded_end_idx >= unsharded_param_start_idx
            )
            if not in_sharded_flat_param:
                shard_param_info = _ShardParamInfo(False, None, None, None, None)
            else:
                if unsharded_start_idx <= unsharded_param_start_idx:
                    # This branch can only happen once since the rank's
                    # unsharded start index can only intersect one parameter
                    intra_param_start_idx = 0
                    offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
                else:
                    intra_param_start_idx = (
                        unsharded_start_idx - unsharded_param_start_idx
                    )
                    offset_in_shard = 0
                assert (
                    offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
                ), (
                    f"Invalid `offset_in_shard` of {offset_in_shard} for "
                    f"sharded flat parameter with {sharded_flat_param_numel} numel"
                )
                intra_param_end_idx = (
                    min(unsharded_param_end_idx, unsharded_end_idx)
                    - unsharded_param_start_idx
                )
                numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
                shard_param_info = _ShardParamInfo(
                    True,
                    offset_in_shard,
                    numel_in_shard,
                    intra_param_start_idx,
                    intra_param_end_idx,
                )
            shard_param_infos.append(shard_param_info)
        return tuple(shard_param_infos)

    @staticmethod
    def _get_unpadded_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``.

        The returned value is a tuple of the shard of ``tensor`` without any
        padding and the numel to pad for that shard.

        If ``tensor`` is already flattened or may be viewed in the flattened
        shape (which is true in the expected usage), then this method does not
        allocate any new tensor memory.
        """
        chunks = torch.flatten(tensor).chunk(world_size)
        if len(chunks) < (rank + 1):
            # This rank gets an empty chunk fully padded with zeros since there
            # are not enough chunks across ranks
            chunk = chunks[0].new_empty(0)
        else:
            chunk = chunks[rank]
        numel_to_pad = chunks[0].numel() - chunk.numel()
        assert (
            numel_to_pad >= 0
        ), "Chunk's size should be at most the first chunk's size"
        return chunk, numel_to_pad
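
    # Illustrative sketch (assumed values): for a flat parameter of numel 10
    # and `world_size=4`, `torch.flatten(tensor).chunk(4)` yields chunks of
    # numel (3, 3, 3, 1), so rank 3's unpadded shard has numel 1 with
    # `numel_to_pad=2`, and `_get_shard` below pads it back to numel 3 to match
    # the other ranks.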

    @staticmethod
    def _get_shard(
        tensor: Tensor,
        rank: int,
        world_size: int,
    ) -> Tuple[Tensor, int]:
        """
        Return the shard of ``tensor`` with padding for the given ``rank`` and
        ``world_size`` and the numel padded for that shard.

        This method allocates new memory (via :meth:`clone`) since the
        unsharded ``tensor`` may be deallocated after this method returns.
        """
        chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        shard = chunk.clone()
        if numel_to_pad > 0:
            shard = F.pad(shard, [0, numel_to_pad])
        return shard, numel_to_pad

    @staticmethod
    def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
        """
        Return the shape of ``tensor`` after sharding, including padding.

        This requires ``tensor`` to have 1D shape and ensures that the returned
        shape is 1D.
        """
        assert len(tensor.shape) == 1, f"{tensor.shape}"
        unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
            tensor, rank, world_size
        )
        unpadded_sharded_size = unpadded_sharded_tensor.size()
        assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
        return torch.Size([unpadded_sharded_size[0] + numel_to_pad])

    def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
        """
        Return [start, end] offsets of each original parameter's flattened data
        in the unsharded flat parameter (i.e. the unpadded unsharded flat
        parameter).

        NOTE: The returned list includes elements for alignment padding.
        """
        cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
        starts = [0] + cumulative_sum[:-1]
        ends = [end - 1 for end in cumulative_sum]  # inclusive
        param_offsets = list(zip(starts, ends))
        return param_offsets
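
    # Illustrative sketch (assumed values): if `_numels_with_padding == (3, 1, 6)`,
    # then `_get_flat_param_offsets()` returns `[(0, 2), (3, 3), (4, 9)]`, i.e.
    # inclusive [start, end] offsets into the unpadded unsharded flat parameter,
    # including an entry for the 1-numel alignment padding.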

    @no_type_check
    def shard_metadata(
        self,
    ) -> FlatParamShardMetadata:
        """
        Return the shard-related metadata specific to this rank's shard of the flat parameter.

        NOTE: The returned tuple does not include elements for alignment
        padding but does account for the padding.
        """
        fqns_list = []
        shapes_list = []
        numels_list = []
        shard_param_offsets = []
        for fqn, shape, numel, shard_param_info in zip(
            self.flat_param._fqns,
            self.flat_param._shapes,
            self.flat_param._numels,
            self.flat_param._shard_param_infos,
        ):
            if not shard_param_info.in_shard:
                continue
            fqns_list.append(fqn)
            shapes_list.append(shape)
            numels_list.append(numel)
            shard_param_offsets.append(
                (
                    shard_param_info.intra_param_start_idx,
                    shard_param_info.intra_param_end_idx,
                )
            )
        return FlatParamShardMetadata(
            tuple(fqns_list),
            tuple(shapes_list),
            tuple(numels_list),
            tuple(shard_param_offsets),
        )

    @no_type_check
    @torch.no_grad()
    def init_flat_param_attributes(self) -> None:
        """
        This initializes some attributes on the handle's ``FlatParameter``.

        This should be called during lazy initialization since it requires the
        parameter to be on the compute device if not offloading to CPU and we
        want to give users the chance to move the parameter appropriately after
        the FSDP constructor.

        For each tensor attribute on the ``FlatParameter``, see the unshard and
        reshard methods in this class for the allocation and free pattern.
        """
        flat_param = self.flat_param
        if flat_param.dtype != self._orig_param_dtype:
            # Entering this branch means that the user changed the parameter
            # dtype after FSDP initialization, in which case we may need to
            # refresh some saved dtype attributes (dtypes specified as a part
            # of mixed precision take precedence).
            if not self._low_prec_param_dtype_specified:
                self._fwd_bwd_param_dtype = flat_param.dtype
            # For `reduce_dtype`, require `param_dtype` was not specified since
            # then we infer the `reduce_dtype` from the specified `param_dtype`
            if (
                not self._low_prec_reduce_dtype_specified
                and not self._low_prec_param_dtype_specified
            ):
                self._reduce_dtype = flat_param.dtype
            self._orig_param_dtype = flat_param.dtype
        cpu_device = torch.device("cpu")
        if self._offload_params:
            _p_assert(
                flat_param.device == cpu_device,
                f"Expects the `FlatParameter` to be on CPU when parameter CPU "
                f"offloading is enabled, not {flat_param.device}",
            )
        else:
            self._check_on_compute_device(self.flat_param)
        flat_param._local_shard = flat_param.data
        if self._offload_params:
            # Pin the memory for faster H2D transfer
            flat_param._local_shard = flat_param._local_shard.pin_memory(
                device=self.device
            )
            # Pre-allocate the sharded gradient on CPU to enable non-blocking
            # D2H transfer during the backward pass
            flat_param._cpu_grad = torch.zeros_like(
                flat_param._local_shard, device=cpu_device
            ).pin_memory(device=self.device)
        if self._uses_param_mixed_precision:
            # For parameter mixed precision, we maintain a low precision
            # sharded tensor on the compute device to be all-gathered (for
            # sharded strategies) or directly used (for `NO_SHARD`) for
            # computation.
            flat_param._mp_shard = torch.empty_like(
                flat_param._local_shard,
                device=self.device,
                dtype=self._fwd_bwd_param_dtype,
            )
            _free_storage(flat_param._mp_shard)
        if self.uses_sharded_strategy:
            # We maintain a padded unsharded tensor that serves as the
            # all-gather destination and owns the original parameter storages.
            unsharded_param_dtype = (
                self._fwd_bwd_param_dtype
                if self._uses_param_mixed_precision
                else flat_param.dtype
            )  # use low precision if parameter mixed precision is enabled
            padded_unsharded_numel = flat_param.numel() * self.world_size
            flat_param._full_param_padded = torch.empty(
                padded_unsharded_numel,
                device=self.device,
                dtype=unsharded_param_dtype,
            )
            flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
            _free_storage(flat_param._full_param_padded)

            if self._uses_param_mixed_precision:
                # For parameter mixed precision, we maintain a full precision
                # padded unsharded tensor for when we force full precision.
                flat_param._full_prec_full_param_padded = torch.empty(
                    padded_unsharded_numel,
                    device=self.device,
                    dtype=flat_param.dtype,  # full precision
                )
                _free_storage(flat_param._full_prec_full_param_padded)

    ###################
    # UNSHARD/RESHARD #
    ###################
    def pre_unshard(self) -> bool:
        """
        Return ``False`` if this is a no-op and ``True`` otherwise.

        Postcondition: ``self.flat_param`` 's data is on the device for
        communication and is what should be all-gathered. This means that it
        matches the dtype of the expected unsharded parameter.
        """
        if (
            self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
            and self._skipped_use_sharded_views
        ):
            # Since this path imposes special semantics for the unsharded flat
            # parameter (e.g. forcing full precision), use sharded views to
            # reuse the existing logic for that special handling
            self._use_sharded_views()
        ret = False
        if self._use_orig_params and not self._skip_writeback_check:
            ret = self._writeback_orig_params()
        if (
            self.uses_sharded_strategy
            and not self._offload_params
            and not self.needs_unshard()
        ):
            pass  # no-op
        elif self._uses_param_mixed_precision and not self._force_full_precision:
            self._use_low_precision_shard()
            ret = True
        elif self._offload_params and self.flat_param.device != self.device:
            # NOTE: This creates a new tensor distinct from any attributes.
            self.flat_param_to(self.device, non_blocking=True)
            ret = True
        self._check_on_compute_device(self.flat_param)
        return ret

    def _use_low_precision_shard(self):
        """Allocate on the compute device and switch to using the low precision sharded flat parameter."""
        self._check_low_precision_shard()
        flat_param = self.flat_param
        _alloc_storage(
            flat_param._mp_shard, flat_param._local_shard.size()  # type: ignore[attr-defined]
        )
        # `copy_()` implicitly casts to the low precision
        flat_param._mp_shard.copy_(  # type: ignore[attr-defined]
            flat_param._local_shard.to(  # type: ignore[attr-defined]
                self.device, non_blocking=True
            )
        )
        # Invariant: `_mp_shard` is always on the compute device.
        flat_param.data = flat_param._mp_shard  # type: ignore[attr-defined]

    def unshard(self):
        """
        Run the unshard logic.

        This includes all-gathering the flat parameter and switching to using
        the unsharded flat parameter. If the handle does not need unsharding,
        then this only switches to using the unsharded flat parameter. For
        ``NO_SHARD``, this is a no-op.

        If FSDP is in :meth:`summon_full_params` and the handle uses parameter
        mixed precision, then the parameter is forced to full precision.
        """
        if not self.needs_unshard():
            # Even when not needing an unshard, we should switch to using
            # the unsharded flat parameter
            unsharded_flat_param = (
                self._get_padded_unsharded_flat_param()
                if self.uses_sharded_strategy
                else self.flat_param
            )
            self._use_unsharded_flat_param(unsharded_flat_param)
            return
        unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
        padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
        self._use_unsharded_flat_param(padded_unsharded_flat_param)

    def needs_unshard(self) -> bool:
        """Return if the handle's flat parameter needs to be unsharded."""
        if not self.uses_sharded_strategy:
            return False
        unsharded_flat_param = self._get_padded_unsharded_flat_param()
        already_unsharded = _same_storage_size(
            unsharded_flat_param, unsharded_flat_param.numel()
        )
        return not already_unsharded

    def _alloc_padded_unsharded_flat_param(self):
        """
        Allocate the *padded* unsharded flat parameter.

        The unpadded unsharded flat parameter is always a view into the padded
        one. This padded parameter is saved to a different attribute on the
        ``FlatParameter`` depending on if we force full precision.
        """
        self._check_sharded_strategy()
        flat_param = self.flat_param
        unsharded_flat_param = self._get_padded_unsharded_flat_param()
        self._check_storage_freed(unsharded_flat_param)
        _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size)  # type: ignore[attr-defined]
        return unsharded_flat_param

    def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
        """
        Return a reference to the padded unsharded flat parameter depending on the calling context.

        This should only be called if using a sharded strategy.
        """
        self._check_sharded_strategy()
        flat_param = self.flat_param
        if self._force_full_precision and self._uses_param_mixed_precision:
            # When parameter mixed precision is enabled, we use a different
            # tensor as the all-gather destination to preserve the invariant
            # that `_full_param_padded` is in the low precision
            unsharded_flat_param = flat_param._full_prec_full_param_padded  # type: ignore[attr-defined]
            _p_assert(
                unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
                f"Expects full precision but got {self._fwd_bwd_param_dtype}",
            )
            # For no-reshard-after-forward strategies, `_full_param_padded` may
            # still be allocated from a previous forward. As we are forcing
            # full precision here, the full-precision unsharded copy may be
            # modified, invalidating the existing low-precision unsharded copy,
            # so we should free it here to ensure a new all-gather for the next
            # forward/backward computation to persist the modifications.
            if flat_param._full_param_padded.untyped_storage().size() > 0:
                _free_storage(flat_param._full_param_padded)
        else:
            unsharded_flat_param = flat_param._full_param_padded  # type: ignore[attr-defined]
        return unsharded_flat_param

    def _all_gather_flat_param(
        self,
        padded_unsharded_flat_param: Tensor,
    ) -> Tensor:
        """
        All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``.

        Then switch to use the all-gathered tensor.
        """
        _p_assert(
            hasattr(self, "process_group") and hasattr(self, "world_size"),
            "Expects a process group and world size to have been set via `shard()`",
        )
        sharded_flat_param = self.flat_param.data
        expected_numel = sharded_flat_param.numel() * self.world_size
        _p_assert(
            padded_unsharded_flat_param.numel() == expected_numel,
            f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
        )

        pg = (
            self._fake_process_group
            if self._use_fake_all_gather
            else self.process_group
        )

        # HACK this should be handled by C10D
        if sharded_flat_param.is_cpu:  # type: ignore[attr-defined]
            tensor_list = list(
                torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))
            )
            dist.all_gather(tensor_list, sharded_flat_param, group=pg)
        else:
            dist.all_gather_into_tensor(
                padded_unsharded_flat_param,
                sharded_flat_param,
                pg,
            )

        if self._offload_params:
            # In case of offloading, `flat_param.data` (i.e. sharded param) is
            # created on the pre-unshard stream. We need to hand it over to the
            # unshard stream for all-gather
            _no_dispatch_record_stream(
                sharded_flat_param,
                self._device_handle.current_stream(),  # unshard_stream
            )
        return padded_unsharded_flat_param
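
    # Illustrative sketch (assumed values): with `world_size=4` and a sharded
    # flat parameter of numel 3, the all-gather destination above has numel 12
    # (`_padded_unsharded_size`); `_use_unsharded_flat_param` below then views
    # only the leading `_unpadded_unsharded_size.numel()` elements (e.g. 10) as
    # the unsharded flat parameter, ignoring the trailing padding.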

    def _use_unsharded_flat_param(
        self,
        padded_unsharded_flat_param: torch.Tensor,
    ) -> None:
        """
        Switch to use the *unpadded* unsharded flat parameter.

        This is a view into the *padded* unsharded flat parameter.
        """
        unsharded_size = self.flat_param._unpadded_unsharded_size
        flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()]
        # slicing [:] is not visible to autograd because of .data
        self.flat_param.data = flat_param_part
        in_forward = self._training_state == HandleTrainingState.FORWARD
        in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
        if self._use_orig_params:
            if self._skipped_use_sharded_views and in_pre_backward:
                # This call corresponds to the complementary pre-backward
                # `_use_unsharded_views()` to the skipped pre-forward
                # `_use_sharded_views()`, so we should skip this one too.
                return
            # We use `Tensor` views in the forward so that they are tracked by
            # autograd. We use them in the pre-backward as well to support
            # reentrant activation checkpointing, which needs the views to be
            # tracked by autograd in the backward pass's recomputed forward.
            self._use_unsharded_views(
                as_params=(not in_forward and not in_pre_backward)
            )
        elif in_forward:
            self._use_unsharded_views(as_params=False)

    def post_unshard(self):
        """
        Run the post-unshard logic.

        This includes freeing the low precision shard if needed.
        """
        if self._uses_param_mixed_precision and self.uses_sharded_strategy:
            self._free_low_precision_sharded_param()
        self._check_on_compute_device(self.flat_param)

    def _free_low_precision_sharded_param(self):
        """Free the low precision sharded flat parameter."""
        self._check_low_precision_shard()
        # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
        # unshard stream for sharded strategies, and consumed in both the
        # unshard and default streams for `NO_SHARD`. For sharded strategies,
        # the current stream here is the unshard stream, and for `NO_SHARD`,
        # it is the default stream. For `NO_SHARD`, only recording for the
        # default stream suffices since the default stream waits for the
        # unshard stream.
        _no_dispatch_record_stream(
            self.flat_param._mp_shard, self._device_handle.current_stream()  # type: ignore[attr-defined]
        )
        _free_storage(self.flat_param._mp_shard)  # type: ignore[attr-defined]

    @torch.no_grad()
    def unshard_grad(self):
        """
        Unshard the handle's ``FlatParameter``'s gradient.

        If all ranks have ``None`` gradient, then all original parameters'
        gradients will be ``None`` as well. This method performs an all-reduce
        and an all-gather. The additional all-reduce is tolerable since this
        method is not meant to be used on the computation critical path.

        Postcondition: ``_saved_grad_shard`` is defined and contains the value
        to set ``flat_param.grad`` after gradients are resharded.
1488 """ 1489 if not self.uses_sharded_strategy: 1490 self._use_unsharded_grad_views() 1491 return 1492 flat_param = self.flat_param 1493 self._check_unsharded(flat_param) 1494 1495 # Check if all ranks have a `None` gradient 1496 num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device) 1497 num_grad_none[0] = flat_param.grad is None 1498 dist.all_reduce(num_grad_none, group=self.process_group) 1499 if num_grad_none[0] == self.world_size: 1500 flat_param._saved_grad_shard = None # type: ignore[assignment] 1501 self._use_unsharded_grad_views() 1502 return 1503 1504 if flat_param.grad is None: 1505 # In the case that only some ranks have `None` gradient, we use 1506 # zeros to approximate as a best effort attempt 1507 if self._debug_level == dist.DebugLevel.INFO: 1508 warnings.warn( 1509 f"[Rank {self.rank}] Only some but not all ranks have a " 1510 "`None` `FlatParameter` gradient, so FSDP is using zeros to " 1511 "approximate those ranks' sharded gradients being `None`" 1512 ) 1513 flat_param._saved_grad_shard = None # type: ignore[assignment] 1514 sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined] 1515 else: 1516 self._check_sharded(flat_param.grad) 1517 flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined] 1518 sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] 1519 padded_unsharded_grad = torch.empty( 1520 flat_param._padded_unsharded_size, # type: ignore[attr-defined] 1521 device=self.device, 1522 dtype=sharded_grad.dtype, 1523 ) 1524 dist.all_gather_into_tensor( 1525 padded_unsharded_grad, sharded_grad, self.process_group 1526 ) 1527 unsharded_size = self.flat_param._unpadded_unsharded_size 1528 flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view( 1529 unsharded_size 1530 ) 1531 self._use_unsharded_grad_views() 1532 1533 def reshard_grad(self): 1534 if self._use_orig_params: 1535 self._use_sharded_grad_views() 1536 if not self.uses_sharded_strategy: 1537 return 1538 self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined] 1539 delattr(self.flat_param, "_saved_grad_shard") 1540 1541 def prepare_gradient_for_backward(self): 1542 """ 1543 Prepare the gradient for the backward computation. 1544 1545 This is done by saving and clearing any existing sharded gradient 1546 in ``.grad`` to enable computing a new unsharded gradient. 1547 """ 1548 _p_assert( 1549 self._training_state 1550 in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE), 1551 "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)", 1552 ) 1553 flat_param = self.flat_param 1554 if flat_param.grad is not None and ( 1555 flat_param.grad.size() != flat_param._unpadded_unsharded_size 1556 or flat_param.grad.device != flat_param.device # grad on CPU 1557 ): 1558 self._check_on_compute_device(self.flat_param) 1559 grad_offloaded = flat_param.grad.device != self.device 1560 _p_assert( 1561 not grad_offloaded or self._offload_params, 1562 f"Expects the sharded gradient to be on {self.device} " 1563 f"but got {flat_param.grad.device}", 1564 ) 1565 prev_iter_synced_gradients = ( 1566 flat_param.grad.size() 1567 == flat_param._local_shard.size() # type: ignore[attr-defined] 1568 ) 1569 if prev_iter_synced_gradients: 1570 # TODO (awgu): Gradient accumulation outside `no_sync()` 1571 # does not work with CPU offloading. 
The issue should be 1572 # that, in the post-backward hook, we cannot do an addition 1573 # between a CPU tensor (the existing sharded gradient) and 1574 # a GPU tensor (the new sharded gradient). 1575 if not grad_offloaded: 1576 flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined] 1577 sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] 1578 else: 1579 _p_assert( 1580 hasattr(flat_param, "_cpu_grad"), 1581 "`_cpu_grad` should be defined if the gradient is on CPU", 1582 ) 1583 sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined] 1584 # If user specified to keep the gradient in low precision, then 1585 # the gradient may still be of the low precision dtype if the 1586 # user did not set the gradient to `None` after the previous 1587 # backward, in which case FSDP should cast back to the full 1588 # precision dtype so that FSDP can accumulate in that dtype in 1589 # the post-backward hook and assign to `.grad` in that dtype in 1590 # the post-backward callback. 1591 local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined] 1592 if ( 1593 self._keep_low_precision_grads 1594 and sharded_grad.dtype != local_shard_dtype 1595 ): 1596 sharded_grad.data = sharded_grad.to(local_shard_dtype) 1597 else: 1598 padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined] 1599 _p_assert( 1600 flat_param.grad.size() == padded_unsharded_size, 1601 "Expects `.grad` to be the unsharded gradient in " 1602 f"`no_sync()` with size {padded_unsharded_size} " 1603 f"but got size {flat_param.grad.size()}", 1604 ) 1605 flat_param.grad = None 1606 1607 def prepare_gradient_for_optim(self): 1608 """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute.""" 1609 1610 def cast_grad_to_param_dtype_if_needed(flat_param): 1611 # TODO (rohan-varma): test for full precision with keep_low_precision_grads 1612 if not self._force_full_precision and self._keep_low_precision_grads: 1613 _p_assert(flat_param.grad is not None, "Unexpected None grad!") 1614 if flat_param.grad.dtype != self._fwd_bwd_param_dtype: 1615 flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype) 1616 if self._use_orig_params: 1617 self._use_sharded_grad_views() 1618 1619 flat_param = self.flat_param 1620 # TODO (awgu): We should replace these conditional checks to encode 1621 # the logical intention more directly. 1622 if hasattr(flat_param, "_cpu_grad"): 1623 # NOTE: This branch includes `NO_SHARD`. 
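# `_cpu_grad` is only defined when CPU offloading is enabled, in which case
# the sharded gradient already resides on CPU and can be assigned directly
# to `.grad` for the optimizer step.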
1624 self._check_sharded(flat_param) 1625 self._check_on_cpu(flat_param) 1626 flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined] 1627 cast_grad_to_param_dtype_if_needed(flat_param)
1628 elif hasattr(flat_param, "_saved_grad_shard"): 1629 self._check_sharded(flat_param) 1630 self._check_on_compute_device(flat_param) 1631 if flat_param._saved_grad_shard is not None: 1632 self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined]
1633 # If no sharded gradient was computed this iteration, then there is 1634 # no need to forward `_saved_grad_shard` to `grad` 1635 if flat_param._post_backward_called: # type: ignore[attr-defined] 1636 flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined] 1637 if flat_param.grad is not None: 1638 cast_grad_to_param_dtype_if_needed(flat_param)
1639 else: 1640 _p_assert( 1641 not self.uses_sharded_strategy 1642 or not flat_param._post_backward_called, # type: ignore[attr-defined] 1643 "All sharded parameters that received a gradient in the " 1644 "post-backward should use `_saved_grad_shard`", 1645 )
1646 # Delete `_saved_grad_shard` since its existence indicates a previous 1647 # gradient to accumulate with in the post-backward hook 1648 if hasattr(flat_param, "_saved_grad_shard"): 1649 delattr(flat_param, "_saved_grad_shard") 1650
1651 @contextlib.contextmanager 1652 def to_cpu(self): 1653 """ 1654 Move the unpadded unsharded flat parameter to CPU while in the context and move it back to the previous device upon exit. 1655
1656 For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter 1657 since (1) there is no reason to include the padding in the copy and (2) 1658 there is no use case for the sharded flat parameter. 1659
1660 Precondition: ``self.flat_param`` 's data is the unpadded unsharded 1661 flat parameter on the compute device, and the handle uses a sharded 1662 strategy. 1663 Postcondition: Same as the precondition. 1664 """
1665 self._check_sharded_strategy() 1666 _p_assert( 1667 self.flat_param.size() == self.flat_param._unpadded_unsharded_size, 1668 f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", 1669 ) 1670 self._check_on_compute_device(self.flat_param)
1671 # Check that the unpadded unsharded flat parameter is a view into the 1672 # padded unsharded flat parameter as expected 1673 # NOTE: This check is not strictly needed for correctness but is a 1674 # useful sanity check since the tensor should only be used internally. 1675 _p_assert( 1676 _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()), 1677 "Expects the unpadded parameter to be a view into the padded parameter", 1678 )
1679 self.flat_param_to(torch.device("cpu")) 1680 self._free_unsharded_flat_param() 1681 try: 1682 yield 1683 finally: 1684 _p_assert( 1685 self.flat_param.size() == self.flat_param._unpadded_unsharded_size, 1686 f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", 1687 )
1688 padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param() 1689 # Copy from CPU to the compute device 1690 padded_unsharded_flat_param[: self.flat_param.numel()].copy_( 1691 self.flat_param 1692 ) 1693 self._use_unsharded_flat_param(padded_unsharded_flat_param) 1694
1695 def reshard(self, free_unsharded_flat_param: bool): 1696 """ 1697 Run the reshard logic.
1698 1699 This includes freeing the unsharded flat 1700 parameter if ``free_unsharded_flat_param`` is ``True`` and switching to using the 1701 sharded flat parameter. Note that this also implicitly offloads 1702 the sharded flat parameter (if CPU offload is enabled) by pointing 1703 it to the ``_local_shard`` attribute, which resides on CPU. 1704 """
1705 # Switch to the sharded `FlatParameter` before freeing to prevent 1706 # "use-after-free"-type bugs with external profiling tools, where for 1707 # `use_orig_params=True`, the `param` does not point to valid memory 1708 # when setting `param.data = ...` in `_use_sharded_views()`. 1709 self._use_sharded_flat_param() 1710 if free_unsharded_flat_param: 1711 self._free_unsharded_flat_param() 1712
1713 def post_reshard(self): 1714 """ 1715 Run the post-reshard logic. 1716 1717 This includes freeing any memory that 1718 can now be freed given that the ``FlatParameter`` points to the full 1719 precision sharded flat parameter. 1720
1721 Precondition: ``self.flat_param`` 's data points to the full precision 1722 sharded flat parameter. 1723 """
1724 # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it 1725 # is also the low precision *unsharded* flat parameter. Hence, we delay 1726 # the free until the reshard. 1727 if ( 1728 self._uses_param_mixed_precision 1729 and not self.uses_sharded_strategy 1730 and not self._force_full_precision # did not use the low precision shard 1731 ): 1732 self._free_low_precision_sharded_param() 1733
1734 def _free_unsharded_flat_param(self): 1735 """ 1736 Free the padded unsharded flat parameter. We allow this 1737 function to be called even when the storage is not allocated. 1738 1739 The tensor to free depends 1740 on the calling context since the unshard may have forced full 1741 precision, in which case a different tensor is used.
1742 """ 1743 self._check_sharded_strategy() 1744 unsharded_flat_param = self._get_padded_unsharded_flat_param() 1745 self._check_on_compute_device(unsharded_flat_param) 1746 # Do not free the memory until all ops in the current stream finish 1747 _no_dispatch_record_stream( 1748 unsharded_flat_param, self._device_handle.current_stream() 1749 ) 1750 _free_storage(unsharded_flat_param) 1751 1752 def _use_sharded_flat_param(self) -> None: 1753 """Switches to using the sharded flat parameter.""" 1754 flat_param = self.flat_param 1755 if self._use_orig_params: 1756 in_forward = self._training_state == HandleTrainingState.FORWARD 1757 skip_use_sharded_views = ( 1758 torch.is_grad_enabled() 1759 and in_forward 1760 and self._sharding_strategy 1761 in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES 1762 ) 1763 # Only incur the extra `.data` call if needed 1764 if skip_use_sharded_views: 1765 unsharded_flat_param = flat_param.data 1766 if self._offload_params: 1767 device = flat_param._local_shard.device # type: ignore[attr-defined] 1768 _p_assert( 1769 device == torch.device("cpu"), 1770 f"Expects the local shard to be on CPU but got {device}", 1771 ) 1772 flat_param.data = flat_param._local_shard # type: ignore[attr-defined] 1773 if self._use_orig_params: 1774 if skip_use_sharded_views: # type: ignore[possibly-undefined] 1775 self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined] 1776 else: 1777 self._use_sharded_views() 1778 # For the post-forward reshard, we may try to use sharded gradient 1779 # views (or unsharded gradient views if a gradient was accumulated 1780 # in `no_sync()`), but for the post-backward reshard, we delay the 1781 # call to after the reduce-scatter. 1782 if ( 1783 in_forward # type: ignore[possibly-undefined] 1784 # Skip using gradient views if skipped using sharded views 1785 # since exposing unsharded parameters with sharded gradients 1786 # may be confusing to the user 1787 and not self._skipped_use_sharded_views 1788 ): 1789 # TODO: Change `_unpadded_unsharded_size` if we change the 1790 # gradient to be computed directly with padding. 1791 accumulated_grad_in_no_sync = ( 1792 flat_param.grad is not None 1793 and self.uses_sharded_strategy 1794 and flat_param.grad.shape == flat_param._unpadded_unsharded_size 1795 ) 1796 if accumulated_grad_in_no_sync: 1797 self._use_unsharded_grad_views() 1798 else: 1799 self._use_sharded_grad_views() 1800 1801 ######### 1802 # VIEWS # 1803 ######### 1804 @no_type_check 1805 def _get_unflat_views_unaligned( 1806 self, 1807 tensor: Optional[torch.Tensor] = None, 1808 ) -> Iterator[Tensor]: 1809 """ 1810 Return unflattened ``Tensor`` views into ``tensor``. 1811 1812 If `tensor`` is ``None``, ``flat_param`` is used. The unflattening is based 1813 on ``flat_param`` 's metadata. 1814 1815 Examples for ``tensor`` include ``flat_param.grad`` or unsharded 1816 tensor optimizer state. 
1817 """ 1818 flat_param = self.flat_param 1819 if tensor is None: 1820 tensor = flat_param 1821 views = ( 1822 _ext_post_unflatten_transform( 1823 subtensor.view(shape), 1824 param_extension, 1825 self._fsdp_extension, 1826 ) 1827 for (subtensor, shape, param_extension) in zip( 1828 torch.split(tensor, flat_param._numels, dim=0), 1829 flat_param._shapes, 1830 flat_param._param_extensions, 1831 ) 1832 ) 1833 return views 1834 1835 @no_type_check 1836 def _get_unflat_views_aligned( 1837 self, 1838 tensor: Optional[Tensor] = None, 1839 ) -> List[Tensor]: 1840 """ 1841 Return unflattened ``Tensor`` views into ``tensor`` with handling for padding. 1842 1843 This method has the same contract as :meth:`_get_unflat_views_unaligned` 1844 except it checks for ``None`` placeholders representing padding for 1845 alignment, which may incur slightly more CPU overhead. 1846 """ 1847 flat_param = self.flat_param 1848 if tensor is None: 1849 tensor = flat_param 1850 splits: List[Tensor] = torch.split( 1851 tensor, flat_param._numels_with_padding, dim=0 1852 ) 1853 idx = 0 1854 views: List[Tensor] = [] 1855 for split, is_padding in zip(splits, flat_param._is_padding_mask): 1856 if is_padding: 1857 continue 1858 views.append( 1859 _ext_post_unflatten_transform( 1860 split.view(flat_param._shapes[idx]), 1861 flat_param._param_extensions[idx], 1862 self._fsdp_extension, 1863 ) 1864 ) 1865 idx += 1 1866 return views 1867 1868 @no_type_check 1869 @torch.enable_grad() 1870 def _use_unsharded_views(self, as_params: bool) -> None: 1871 """ 1872 Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it. 1873 1874 Args: 1875 as_params (bool): If ``True``, then registers the original 1876 parameters as ``nn.Parameter`` s; if ``False``, then registers 1877 the original parameters only as ``Tensor`` s. ``False`` should 1878 be used during forward/backward computation and when hiding the 1879 original parameters from :meth:`nn.Module.named_parameters`. 1880 1881 Note: 1882 when prefetching for next forward, current forward may be 1883 annotated with `@torch.no_grad()` 1884 `@torch.enable_grad()` ensures non-empty `view.grad_fn` 1885 otherwise `_post_backward_hook` will not get called 1886 """ 1887 flat_param = self.flat_param 1888 self._check_unsharded(flat_param) 1889 views = self._get_unflat_views() 1890 from torch.distributed.tensor import DTensor 1891 1892 for i, (view, (param_name, module, _)) in enumerate( 1893 zip(views, flat_param._param_infos) 1894 ): 1895 if self._use_orig_params and as_params: 1896 if type(view) is DTensor: 1897 # A `DTensor` `view` is not compatible with assigning 1898 # `param.data = view`, so we cannot preserve the parameter 1899 # variable. 
1900 self._setattr_param( 1901 module, 1902 param_name, 1903 nn.Parameter(view, requires_grad=flat_param.requires_grad), 1904 ) 1905 continue 1906 param = self.flat_param._params[i] 1907 self._setattr_param(module, param_name, param) 1908 param.data = view 1909 elif as_params: 1910 self._setattr_param( 1911 module, 1912 param_name, 1913 nn.Parameter(view, requires_grad=flat_param.requires_grad), 1914 ) 1915 else: # `as_params=False` 1916 param_var: Tensor = view 1917 if self._use_orig_params: 1918 if self._training_state == HandleTrainingState.FORWARD: 1919 # Save the `Tensor` for the pre-backward 1920 self.flat_param._tensors[i] = view # save for pre-backward 1921 elif self._training_state == HandleTrainingState.BACKWARD_PRE: 1922 # Use the saved `Tensor` variable from the forward to 1923 # preserve the autograd graph so that the post-backward 1924 # hook fires (e.g. for reentrant AC) 1925 tensor = self.flat_param._tensors[i] 1926 tensor.data = view 1927 param_var = tensor 1928 self._setattr_tensor(module, param_name, param_var) 1929 if ( 1930 self._use_orig_params 1931 and self._training_state == HandleTrainingState.FORWARD 1932 ): 1933 module._parameters[param_name] = param_var 1934 for i, ( 1935 param_name, 1936 module, 1937 _, 1938 prim_param_name, 1939 prim_module, 1940 _, 1941 ) in enumerate(self.flat_param._shared_param_infos): 1942 prim_param: Union[Tensor, nn.Parameter] = getattr( 1943 prim_module, prim_param_name 1944 ) 1945 _p_assert( 1946 not as_params or isinstance(prim_param, nn.Parameter), 1947 f"as_params={as_params} type(prim_param)={type(prim_param)}", 1948 ) 1949 if self._use_orig_params and as_params: 1950 shared_param = self.flat_param._shared_params[i] 1951 self._setattr_param(module, param_name, shared_param) 1952 shared_param.data = prim_param 1953 elif as_params: 1954 self._setattr_param(module, param_name, prim_param) 1955 else: 1956 self._setattr_tensor(module, param_name, prim_param) 1957 if ( 1958 self._use_orig_params 1959 and self._training_state == HandleTrainingState.FORWARD 1960 ): 1961 module._parameters[param_name] = prim_param 1962 1963 @no_type_check 1964 def _use_unsharded_grad_views(self) -> None: 1965 """ 1966 Unflatten the unsharded flat parameter's gradient. 1967 1968 The original parameter variables' gradients are set to be views into 1969 the unsharded flat parameter's gradient. 1970 """ 1971 # Expects the gradient to be in `flat_param.grad` 1972 if self.flat_param.grad is None: 1973 for param in chain(self.flat_param._params, self.flat_param._shared_params): 1974 param.grad = None 1975 return 1976 self._check_unsharded(self.flat_param.grad) 1977 views = self._get_unflat_views(self.flat_param.grad) 1978 for i, (view, (param_name, module, _)) in enumerate( 1979 zip(views, self.flat_param._param_infos) 1980 ): 1981 _p_assert( 1982 hasattr(module, param_name), 1983 f"{self.flat_param._fqns[i]} is missing", 1984 ) 1985 param = getattr(module, param_name) 1986 if ( 1987 param.shape != view.shape 1988 or param.dtype != view.dtype 1989 or param.device != view.device 1990 ): 1991 # NOTE: This is a hack using `.data` to side step the check 1992 # that parameter/gradient sizes/dtypes/devices match. From 1993 # calling `reshard()`, `param` has the sharded size, has the 1994 # full precision dtype, and if CPU offloading is enabled, is on 1995 # CPU. Thus, one or more of the following cases can hold when 1996 # in `no_sync()`, where `view` is the original parameter's 1997 # gradient: 1998 # 1. `view` can have the unsharded size. 1999 # 2. 
`view` can have the parameter low precision dtype. 2000 # 3. `view` can be on GPU. 2001 if param.grad is None: 2002 param.grad = torch.empty_like(param) 2003 param.grad.data = view 2004 else: 2005 param.grad = view
2006 for i, ( 2007 param_name, 2008 module, 2009 module_name, 2010 prim_param_name, 2011 prim_module, 2012 _, 2013 ) in enumerate(self.flat_param._shared_param_infos): 2014 _p_assert( 2015 hasattr(module, param_name), 2016 f"{module_name + '.' + param_name if module_name else param_name} is missing", 2017 ) # did not save FQN info in `_shared_param_infos`
2018 param = getattr(module, param_name) 2019 prim_param = getattr(prim_module, prim_param_name) 2020 if ( 2021 param.shape != prim_param.grad.shape 2022 or param.dtype != prim_param.grad.dtype 2023 or param.device != prim_param.grad.device 2024 ): 2025 # NOTE: This is the same hack to use `.data` to side step the 2026 # size check. 2027 if param.grad is None: 2028 param.grad = torch.empty_like(param) 2029 param.grad.data = prim_param.grad 2030 else: 2031 param.grad = prim_param.grad 2032
2033 @contextlib.contextmanager 2034 def unflatten_as_params(self) -> Generator: 2035 """ 2036 Unflatten the original parameters. 2037
2038 The function assumes that the flat parameter is unsharded. While in the context, it 2039 unflattens the original parameters as ``nn.Parameter`` views into the 2040 flat parameter, and after the context, it restores the original parameters 2041 as ``Tensor`` views into the flat parameter. 2042 """
2043 self._use_unsharded_views(as_params=True) 2044 try: 2045 yield 2046 finally: 2047 self._use_unsharded_views(as_params=False) 2048
2049 @no_type_check 2050 @torch.no_grad() 2051 def _use_sharded_views(self) -> None: 2052 """ 2053 Set the original parameter variables' data to be flattened views into the sharded flat parameter. 2054
2055 The views are kept as flattened to simplify the case where a parameter 2056 is sharded across ranks. Parameters whose data is not present in the 2057 sharded flat parameter have their data set to a size-0 empty tensor. We 2058 do not delete them, so as to preserve expected behaviors like model 2059 printability. Parameters whose data is present must preserve their 2060 variables to be passable to an optimizer.
2061 """ 2062 self._unsharded_flat_param_for_skipped_views = None 2063 if not self.uses_sharded_strategy: 2064 # For `NO_SHARD`, use the *unflattened* unsharded views since we 2065 # have the unsharded parameter 2066 self._use_unsharded_views(as_params=True) 2067 return 2068 flat_param = self.flat_param 2069 self._check_sharded(flat_param) 2070 # Construct once and reuse for all parameters not in the local shard 2071 size_0_empty_tensor = torch.empty( 2072 0, 2073 dtype=self.flat_param.dtype, # in case `flat_param` changed dtype 2074 device=self.flat_param.device, 2075 requires_grad=False, 2076 ) 2077 for param, shard_param_info, (param_name, module, _) in zip( 2078 flat_param._params, flat_param._shard_param_infos, flat_param._param_infos 2079 ): 2080 self._setattr_param(module, param_name, param) 2081 if not shard_param_info.in_shard: 2082 # Allow the original data to be freed via garbage collection 2083 param.data = size_0_empty_tensor 2084 else: 2085 offset = shard_param_info.offset_in_shard 2086 numel_in_shard = shard_param_info.numel_in_shard 2087 param.data = flat_param[offset : offset + numel_in_shard] 2088 assert self.flat_param._shared_params is not None 2089 for i, ( 2090 param, 2091 (param_name, module, _, prim_param_name, prim_module, _), 2092 ) in enumerate( 2093 zip(self.flat_param._shared_params, self.flat_param._shared_param_infos) 2094 ): 2095 self._setattr_param(module, param_name, param) 2096 prim_param = getattr(prim_module, prim_param_name) 2097 param.data = prim_param # could be both empty and non-empty 2098 if self._training_state == HandleTrainingState.BACKWARD_POST: 2099 # Clear the saved `Tensor`s since they are unneeded now 2100 for i in range(len(self.flat_param._tensors)): 2101 self.flat_param._tensors[i] = None 2102 2103 @no_type_check 2104 @torch.no_grad() 2105 def _use_sharded_grad_views(self) -> None: 2106 """ 2107 Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient. 2108 2109 This is a no-op if there is no gradient. 2110 2111 Parameters whose data is not present in the sharded flat parameter and 2112 parameters with ``requires_grad=False`` have their gradients set to 2113 ``None``. Since the gradient variables do not need to be preserved, 2114 this method does not manipulate existing ``Tensor`` data directly and 2115 creates new ``Tensor`` variables instead. 2116 """ 2117 flat_param = self.flat_param 2118 self._check_sharded(flat_param) 2119 grad = self.sharded_grad 2120 if grad is None: 2121 for param in chain(flat_param._params, flat_param._shared_params): 2122 param.grad = None 2123 return 2124 self._check_sharded(grad) 2125 for param, shard_param_info, is_grad_none in zip( 2126 flat_param._params, 2127 flat_param._shard_param_infos, 2128 flat_param._is_grad_none_mask, 2129 ): 2130 if not shard_param_info.in_shard: 2131 param.grad = None 2132 else: 2133 numel_in_shard = shard_param_info.numel_in_shard 2134 if param.requires_grad and not is_grad_none: 2135 offset = shard_param_info.offset_in_shard 2136 if self._keep_low_precision_grads or param.dtype != grad.dtype: 2137 # NOTE: This is a hack using `.data` to side step the 2138 # check that parameter/gradient dtypes match. Here, 2139 # `param` has full precision; `grad` has low precision. 
2140 if param.grad is None: 2141 # `.grad` must have the same shape as `param` 2142 param.grad = torch.empty_like(param) 2143 param.grad.data = grad[ 2144 offset : offset + numel_in_shard 2145 ].reshape(param.shape) 2146 else: 2147 param.grad = grad[offset : offset + numel_in_shard].reshape( 2148 param.shape 2149 ) 2150 else: 2151 param.grad = None 2152 assert flat_param._shared_params is not None 2153 for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate( 2154 zip(flat_param._shared_params, flat_param._shared_param_infos) 2155 ): 2156 in_sharded_flat_param = hasattr(prim_module, prim_param_name) 2157 if in_sharded_flat_param and param.requires_grad: 2158 prim_param = getattr(prim_module, prim_param_name) 2159 param.grad = prim_param.grad # share the same reference 2160 else: 2161 param.grad = None 2162 2163 @no_type_check 2164 @torch.no_grad() 2165 def _writeback_orig_params(self) -> bool: 2166 """ 2167 Write back any parameters that changed storage to the handle's ``FlatParameter``. 2168 2169 Iterates over the original parameters and writes back any parameters 2170 that changed storages (due to a non-inplace operator) to the handle's 2171 ``FlatParameter``. This method preserves the ``FlatParameter` 's 2172 device even if an original parameter's device changes. 2173 2174 Raises: 2175 RuntimeError: If an original parameter or gradient changes storages 2176 but no longer has the expected flattened shape. 2177 Returns: ``True`` if some writeback happened, and ``False`` otherwise. 2178 """ 2179 if ( 2180 self.uses_sharded_strategy 2181 and not self.is_sharded(self.flat_param) 2182 and not self._skipped_use_sharded_views 2183 ): 2184 # For `NO_SHARD`, we may still need to writeback 2185 return False 2186 flat_param = self.flat_param 2187 wroteback = False 2188 if self._skipped_use_sharded_views and self.uses_sharded_strategy: 2189 # NOTE: We must use the unsharded flat parameter from which the 2190 # unsharded views were computed, not the one from the current 2191 # calling context (`_get_padded_unsharded_flat_param()`) since that 2192 # may be different (e.g. the model changed from train to eval). 2193 flat_param_tensor = self._unsharded_flat_param_for_skipped_views 2194 _p_assert( 2195 _data_ptr_allocated(flat_param_tensor), 2196 "If skipped using sharded views, the unsharded flat parameter " 2197 "should be allocated", 2198 ) 2199 else: 2200 flat_param_tensor = flat_param 2201 # NOTE: Since this method is called in the pre-unshard, which is only 2202 # called during computation in the pre-forward or pre-backward, the 2203 # sharded gradient should be guaranteed to be in `.grad`, not in 2204 # `._saved_grad_shard`. 2205 flat_param_grad = ( 2206 flat_param.grad 2207 if self.uses_sharded_strategy or not self._offload_params 2208 else flat_param._cpu_grad 2209 ) 2210 for i, ( 2211 param, 2212 (in_shard, offset_in_shard, numel_in_shard, _, _), 2213 (param_name, module, _), 2214 ) in enumerate( 2215 zip( 2216 flat_param._params, 2217 flat_param._shard_param_infos, 2218 flat_param._param_infos, 2219 ) 2220 ): 2221 if not in_shard: 2222 continue 2223 if not hasattr(module, param_name): 2224 # Do not writeback if original parameters are deregistered 2225 # (e.g. 
during model checkpointing) 2226 continue 2227 2228 # Check for parameter writeback 2229 if self._skipped_use_sharded_views: 2230 param = flat_param._tensors[i] 2231 _p_assert( 2232 param is not None, 2233 f"Expects to have saved tensor for {flat_param._fqns[i]}", 2234 ) 2235 param_changed = getattr(module, param_name) is not param 2236 needs_param_writeback = ( 2237 param_changed # changed parameter variable itself 2238 or not _same_storage(param, flat_param_tensor) 2239 ) 2240 if self._skipped_use_sharded_views and ( 2241 param_changed or needs_param_writeback 2242 ): 2243 raise AssertionError( 2244 "FSDP does not support changing the parameters between " 2245 f"forward and backward for {self._sharding_strategy}" 2246 ) 2247 if param_changed: 2248 # NOTE: The gradient is not preserved after a parameter change. 2249 param = getattr(module, param_name) 2250 flat_param._params[i] = param 2251 if needs_param_writeback: 2252 expected_shape = torch.Size([numel_in_shard]) 2253 self._writeback_tensor( 2254 param, flat_param, i, expected_shape, offset_in_shard, True 2255 ) 2256 wroteback = True 2257 2258 # Check for gradient writeback 2259 if self._skipped_use_sharded_views: 2260 # Skip the writeback check because we do not expose gradients 2261 # when we skipped using sharded views 2262 continue 2263 if param.grad is None and flat_param.grad is not None: 2264 expected_shape = torch.Size([numel_in_shard]) 2265 self._writeback_tensor( 2266 None, flat_param.grad, i, expected_shape, offset_in_shard, False 2267 ) 2268 elif param.grad is not None: 2269 # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in 2270 # memory and owns the gradient storage, so it will never 2271 # require gradient writeback. 2272 if not self.uses_sharded_strategy and self._offload_params: 2273 # Explicitly continue to handle the case of `no_sync()`, 2274 # where `param.grad` is a view into the GPU gradient 2275 # referenced by `flat_param.grad`, while `flat_param_grad` 2276 # is `flat_param._cpu_grad`, which is on CPU 2277 continue 2278 2279 needs_grad_writeback = flat_param_grad is None or not _same_storage( 2280 param.grad, flat_param_grad 2281 ) 2282 if needs_grad_writeback: 2283 if flat_param_grad is None: 2284 flat_param_grad = torch.zeros_like(flat_param) 2285 expected_shape = torch.Size([numel_in_shard]) 2286 self._writeback_tensor( 2287 param.grad, 2288 flat_param_grad, 2289 i, 2290 expected_shape, 2291 offset_in_shard, 2292 False, 2293 ) 2294 flat_param.grad = flat_param_grad 2295 flat_param_grad = flat_param.grad 2296 2297 # TODO: If we want to handle shared parameters, we need to re-generate 2298 # the shared parameter data structures in case sharedness changed. 2299 for i, ( 2300 param_name, 2301 module, 2302 _, 2303 prim_param_name, 2304 prim_module, 2305 _, 2306 ) in enumerate(flat_param._shared_param_infos): 2307 if getattr(module, param_name) is not getattr(prim_module, prim_param_name): 2308 raise NotImplementedError( 2309 "Changing shared parameters is not supported yet" 2310 ) 2311 return wroteback 2312 2313 def _writeback_tensor( 2314 self, 2315 src_tensor: Optional[Tensor], 2316 dst_tensor: Tensor, 2317 tensor_index: int, 2318 expected_shape: torch.Size, 2319 offset: int, 2320 is_param: bool, # else gradient 2321 ) -> None: 2322 """ 2323 Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``. 2324 2325 ``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if 2326 ``False``). 
If ``src_tensor`` is ``None``, then the effect is zeroing 2327 instead of copying. ``tensor_index`` gives the index of ``src_tensor`` 2328 in the metadata structures. 2329 2330 Raises: 2331 RuntimeError: If the ``src_tensor`` does not have the expected 2332 shape. 2333 """ 2334 _p_assert( 2335 len(expected_shape) == 1, 2336 f"Expects a 1D expected shape but got {expected_shape}", 2337 ) 2338 if self._debug_level == dist.DebugLevel.INFO: 2339 rank = self.rank if hasattr(self, "rank") else dist.get_rank() 2340 src_shape = src_tensor.shape if src_tensor is not None else None 2341 src_device = src_tensor.device if src_tensor is not None else None 2342 warnings.warn( 2343 f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs " 2344 f"writeback in {self._training_state}\n" 2345 f"expected shape={expected_shape} shape={src_shape} " 2346 f"expected device={dst_tensor.device} device={src_device}" 2347 ) 2348 if src_tensor is not None and src_tensor.shape != expected_shape: 2349 # NOTE: Gradient shape mismatch is not possible in practice since 2350 # the gradient shape is enforced to match that of the parameter and 2351 # we already check for parameter shape mismatch. 2352 raise RuntimeError( 2353 f"Cannot writeback when the {'parameter' if is_param else 'gradient'} " 2354 f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}" 2355 ) 2356 if src_tensor is not None: 2357 dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor) 2358 else: 2359 dst_tensor[offset : offset + expected_shape.numel()].zero_() 2360 assert self.flat_param._is_grad_none_mask is not None 2361 self.flat_param._is_grad_none_mask[tensor_index] = True 2362 2363 def _reset_flat_param_grad_info_if_needed(self): 2364 """ 2365 Reset ``flat_param.grad`` if needed. 2366 2367 When ``use_orig_params=True``: 2368 (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the 2369 original parameters' ``.grad`` are ``None``, and 2370 (2) sets ``flat_param.requires_grad=False`` if *none* of the original 2371 parameters require gradient. 2372 For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in 2373 which case we want to free the gradients as soon after the 2374 ``zero_grad()`` call as possible. 
2375 """ 2376 if not self._use_orig_params: 2377 return 2378 flat_param = self.flat_param 2379 assert flat_param._params is not None # mypy 2380 all_grad_none = True 2381 requires_grad = False 2382 for param in flat_param._params: 2383 all_grad_none &= param.grad is None 2384 requires_grad |= param.requires_grad 2385 if all_grad_none: 2386 flat_param.grad = None 2387 # As long as one parameter requires gradient, then the flat parameter 2388 # must require gradient 2389 flat_param.requires_grad = requires_grad 2390 2391 def _deregister_orig_params(self): 2392 for param_info in self.flat_param._param_infos: 2393 param_name, module, _ = param_info 2394 if hasattr(module, param_name): 2395 delattr(module, param_name) 2396 for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos: 2397 if hasattr(module, param_name): 2398 delattr(module, param_name) 2399 2400 ########### 2401 # HELPERS # 2402 ########### 2403 def flat_param_to(self, *args, **kwargs): 2404 """Wrap an in-place call to ``.to()`` for ``self.flat_param``.""" 2405 self.flat_param.data = self.flat_param.to(*args, **kwargs) 2406 if self._use_orig_params: 2407 # Refresh the views because their storage may have changed 2408 if self.is_sharded(self.flat_param): 2409 self._use_sharded_views() 2410 else: 2411 self._use_unsharded_views(as_params=True) 2412 2413 def _get_modules(self) -> Set[nn.Module]: 2414 """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter.""" 2415 return {pi.module for pi in self.flat_param._param_infos}.union( 2416 {spi.module for spi in self.flat_param._shared_param_infos} 2417 ) 2418 2419 def is_sharded(self, tensor: Tensor) -> bool: 2420 """ 2421 Return whether ``tensor`` is *currently* sharded. 2422 2423 For ``NO_SHARD``, we choose to have this always return ``False`` for clarity. 
2424 """ 2425 if ( 2426 not hasattr(self.flat_param, "_sharded_size") 2427 or not self.uses_sharded_strategy 2428 ): 2429 # `_sharded_size` is defined iff `handle.shard()` has been called 2430 return False 2431 sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] 2432 return tensor.size() == sharded_size 2433 2434 def param_module_names(self) -> Iterator[Tuple[str, str]]: 2435 shared_param_infos = [ 2436 ParamInfo(param_name, module, module_name) 2437 for ( 2438 param_name, 2439 module, 2440 module_name, 2441 _, 2442 _, 2443 _, 2444 ) in self.flat_param._shared_param_infos 2445 ] 2446 for param_info in chain(self.flat_param._param_infos, shared_param_infos): 2447 param_name, _, module_name = param_info # type: ignore[misc] 2448 yield (param_name, module_name) 2449 2450 def shared_param_module_names(self) -> Iterator[Tuple[str, str]]: 2451 for param_name, _, module_name in [ 2452 ParamInfo(param_name, module, module_name) 2453 for ( 2454 param_name, 2455 module, 2456 module_name, 2457 _, 2458 _, 2459 _, 2460 ) in self.flat_param._shared_param_infos 2461 ]: 2462 yield (param_name, module_name) 2463 2464 @property 2465 def _fqns_in_shard(self) -> List[str]: 2466 """Return the FQNs of the parameters present in this rank's shard.""" 2467 fqns_in_shard: List[str] = [] 2468 for fqn, shard_param_info in zip( 2469 self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined] 2470 ): 2471 if shard_param_info.in_shard: 2472 fqns_in_shard.append(fqn) 2473 return fqns_in_shard 2474 2475 @property 2476 def sharded_grad(self) -> Optional[Tensor]: 2477 """Return the handle's sharded gradient.""" 2478 flat_param = self.flat_param 2479 # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad` 2480 # - CPU offloading: `_cpu_grad` 2481 # - No CPU offloading + sharded strategies: `_saved_grad_shard` 2482 # - No CPU offloading + `NO_SHARD`: `grad` 2483 grad: Optional[Tensor] 2484 if hasattr(flat_param, "_cpu_grad"): 2485 grad = flat_param._cpu_grad # type: ignore[attr-defined] 2486 elif hasattr(flat_param, "_saved_grad_shard"): 2487 # In the post-backward hook, the sharded gradient is still in 2488 # `_saved_grad_shard`. 2489 grad = flat_param._saved_grad_shard # type: ignore[attr-defined] 2490 else: 2491 # If in IDLE or in FORWARD states, then there may be an 2492 # (accumulated) gradient. If accessed in IDLE, then this should 2493 # be due to re-registering the original parameters (e.g. in state 2494 # dict load). 2495 _p_assert( 2496 flat_param.grad is None 2497 or not self.uses_sharded_strategy 2498 or self._training_state 2499 in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE), 2500 "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` " 2501 "unless in IDLE or FORWARD", 2502 ) 2503 grad = flat_param.grad 2504 return grad 2505 2506 def _reset_is_grad_none(self) -> None: 2507 """ 2508 Reset ``_is_grad_none_mask`` as needed. 2509 2510 This method should only be 2511 called in the post-backward after gradient computation, in which case 2512 if a parameter requires gradient, then it will surely receive a 2513 gradient and we may reset its mask entry to ``False``. 
2514 """ 2515 if not self._use_orig_params: 2516 return 2517 _p_assert( 2518 self._training_state == HandleTrainingState.BACKWARD_POST, 2519 "Expects to only be called in the post-backward after gradient computation", 2520 ) 2521 flat_param = self.flat_param 2522 assert flat_param._params is not None # mypy 2523 for i, param in enumerate(flat_param._params): # type: ignore[arg-type] 2524 # As long as the parameter requires gradient, it should receive a 2525 # meaningful gradient (even if the gradient happens to be zeros) 2526 if param.requires_grad: 2527 assert flat_param._is_grad_none_mask is not None # mypy 2528 flat_param._is_grad_none_mask[i] = False 2529 2530 ####################### 2531 # CHECKS & INVARIANTS # 2532 ####################### 2533 def _check_sharded_strategy(self): 2534 _p_assert(self.uses_sharded_strategy, "Expects sharded strategy") 2535 2536 def _check_on_compute_device(self, tensor: Tensor): 2537 _p_assert( 2538 tensor.device == self.device, 2539 f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}", 2540 ) 2541 2542 def _check_on_cpu(self, tensor: Tensor): 2543 _p_assert( 2544 tensor.device == torch.device("cpu"), 2545 f"Expects tensor to be on CPU but got {tensor.device}", 2546 ) 2547 2548 @staticmethod 2549 def _check_storage_freed(tensor: Tensor): 2550 # Compile does not resize during trace 2551 if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): 2552 _p_assert( 2553 _same_storage_size(tensor, 0), 2554 "Expects storage to be freed but got storage with size > 0", 2555 ) 2556 2557 @staticmethod 2558 def _check_storage_allocated(tensor: Tensor): 2559 _p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated") 2560 2561 def _check_low_precision_shard(self): 2562 _p_assert( 2563 self._uses_param_mixed_precision, 2564 "Not using low precision for parameters", 2565 ) 2566 _p_assert( 2567 getattr(self.flat_param, "_mp_shard", None) is not None, 2568 "Expects `_mp_shard` to exist", 2569 ) 2570 device = self.flat_param._mp_shard.device # type: ignore[attr-defined] 2571 _p_assert( 2572 device == self.device, 2573 f"Expects the low precision shard to be on {self.device} but got {device}", 2574 ) 2575 2576 def _check_unsharded(self, tensor: Tensor): 2577 msg_prefix = "Expects tensor to be unsharded " 2578 _p_assert(tensor is not None, msg_prefix + "but got `None`") 2579 unsharded_size = self.flat_param._unpadded_unsharded_size 2580 _p_assert( 2581 tensor.size() == unsharded_size, 2582 msg_prefix + f"with size {unsharded_size} but got {tensor.size()}", 2583 ) 2584 2585 def _check_sharded(self, tensor: Tensor): 2586 msg_prefix = "Expects tensor to be sharded " 2587 _p_assert(tensor is not None, msg_prefix + "but got `None`") 2588 sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] 2589 _p_assert( 2590 tensor.size() == sharded_size, 2591 msg_prefix + f"with size {sharded_size} but got {tensor.size()}", 2592 ) 2593 2594 ############## 2595 # PROPERTIES # 2596 ############## 2597 @property 2598 def uses_sharded_strategy(self) -> bool: 2599 return self._sharding_strategy != HandleShardingStrategy.NO_SHARD 2600 2601 @property 2602 def _uses_param_mixed_precision(self) -> bool: 2603 return self._fwd_bwd_param_dtype != self._orig_param_dtype 2604 2605 @property 2606 def _uses_reduce_mixed_precision(self) -> bool: 2607 return self._reduce_dtype != self._orig_param_dtype 2608 2609 @property 2610 def _force_full_precision(self) -> bool: 2611 return ( 2612 self._uses_param_mixed_precision or 
self._uses_reduce_mixed_precision 2613 ) and ( 2614 self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS 2615 or 2616 # Also disable mixed precision in model eval mode, if configured 2617 (not self._fully_sharded_module.training and self._use_full_prec_in_eval) 2618 ) 2619 2620 @property 2621 def _skipped_use_sharded_views(self) -> bool: 2622 """ 2623 This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``. 2624 2625 This returns if this handle is 2626 currently in a state where it has skipped using sharded views, in which 2627 case it can restore view invariants via ``_use_sharded_views()``. 2628 """ 2629 return self._unsharded_flat_param_for_skipped_views is not None 2630 2631 2632# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks. 2633def _unsafe_setattr_param( 2634 module: nn.Module, param_name: str, param: nn.Parameter 2635) -> None: 2636 module._parameters[param_name] = param 2637 # This bypasses any overrides in case `module` is an instance of an 2638 # `nn.Module` subclass 2639 super(nn.Module, module).__setattr__(param_name, param) 2640 2641 2642def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None: 2643 module._parameters.pop(param_name, None) 2644 # This bypasses any overrides in case `module` is an instance of an 2645 # `nn.Module` subclass 2646 super(nn.Module, module).__setattr__(param_name, tensor) 2647 2648 2649def _safe_setattr_tensor_or_param( 2650 module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter] 2651): 2652 # Call `delattr()` and `setattr()` to go through `nn.Module` checks 2653 if hasattr(module, param_name): 2654 delattr(module, param_name) 2655 setattr(module, param_name, tensor_or_param) 2656 2657 2658def _convert_to_params( 2659 tensors: List[Union[torch.Tensor, nn.Parameter]] 2660) -> List[nn.Parameter]: 2661 return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] 2662 2663 2664def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor: 2665 return ( 2666 param_or_tensor.detach() 2667 if isinstance(param_or_tensor, nn.Parameter) 2668 else param_or_tensor 2669 ) 2670 2671 2672def _get_aligned_numel(unsharded_dtype: torch.dtype): 2673 # NOTE: This alignment constraint comes from TorchInductor. 2674 ALIGNMENT = 16 # bytes 2675 unsharded_dtype_size = _get_dtype_size(unsharded_dtype) 2676 aligned_numel = ALIGNMENT // unsharded_dtype_size 2677 return aligned_numel 2678 2679 2680@functools.lru_cache(8) 2681def _get_dtype_size(dtype): 2682 return torch.empty((), dtype=dtype).element_size() 2683 2684 2685def _construct_padding_tensor( 2686 padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device 2687): 2688 # NOTE: Set the padding value as a magic number for debuggability. The 2689 # value itself should never be used in any user-facing computation. 
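# For example (illustrative):
# `_construct_padding_tensor(2, torch.float32, False, torch.device("cpu"))`
# returns `tensor([42., 42.])`.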
2690 return ( 2691 torch.ones( 2692 (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device 2693 ) 2694 * _FLAT_PARAM_PADDING_VALUE 2695 ) 2696 2697
2698 # Use `lru_cache(1)` to only log the warning once (assuming the fixed warning 2699 # message is passed in) 2700 @functools.lru_cache(1) 2701 def _warn_skip_writeback_check(log: logging.Logger, warning: str): 2702 logger.warning(warning) 2703 2704
2705 # Use `lru_cache(1)` to only log the warning once 2706 @functools.lru_cache(1) 2707 def _warn_use_fake_all_gather(log: logging.Logger, warning: str): 2708 logger.warning(warning) 2709 2710
2711 # Use `lru_cache(1)` to only log the warning once 2712 @functools.lru_cache(1) 2713 def _warn_use_fake_reduce(log: logging.Logger, warning: str): 2714 logger.warning(warning) 2715 2716
2717 def _same_storage(a, b): 2718 # Params are DTensors in backward 2719 # with SHARD_GRAD_OP + TP 2720 from torch.distributed.tensor import DTensor 2721 2722 if isinstance(a, DTensor): 2723 a = a._local_tensor 2724 if isinstance(b, DTensor): 2725 b = b._local_tensor 2726 return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr() 2727 2728
2729 def _same_storage_size(a: torch.Tensor, b: int): 2730 return a.untyped_storage().size() // a.element_size() == b 2731 2732
2733 def _storage_size_allocated(tensor: Tensor): 2734 storage_size: int = tensor.untyped_storage().size() 2735 return storage_size > 0 2736
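# Illustrative examples for the helpers above (assumed values; not executed by
# this module):
#   _get_dtype_size(torch.float32) == 4            # bytes per element
#   _get_aligned_numel(torch.float32) == 16 // 4 == 4
#   _construct_padding_tensor(2, torch.float32, False, torch.device("cpu"))
#       -> tensor([42., 42.])                      # `_FLAT_PARAM_PADDING_VALUE`
#   _same_storage_size(t, 0)   # used by `_check_storage_freed` to verify that
#                              # `t`'s storage was freed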