# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]

import contextlib
import os
import re
import sys
import warnings
from abc import ABC, abstractmethod
from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    List,
    no_type_check,
    Optional,
    Tuple,
    Type,
    Union,
)
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_param_group import (
    FSDPParamGroup,
    RegisterPostBackwardFunction,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import TrainingState
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
from torch.distributed.tensor import distribute_tensor, DTensor, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
    SequenceParallel,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    run_subtests,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms
from torch.utils._triton import has_triton


class FSDPInitMode(Enum):
    # No FSDP wrapping
    NO_FSDP = auto()
    # FSDP recursive wrapping
    RECURSIVE = auto()
    # TODO: FSDP non-recursive wrapping
    # NONRECURSIVE = auto()


class CUDAInitMode(Enum):
    # Move model to CUDA before passing to the FSDP constructor
    CUDA_BEFORE = auto()
    # Move model to CUDA after passing to the FSDP constructor
    CUDA_AFTER = auto()
    # Keep on CPU
    CUDA_NEVER = auto()


class FSDPTestModel(nn.Module, ABC):
    """This defines the interface expected from all models used commonly for
    FSDP unit tests."""

    @abstractmethod
    def get_input(self, device) -> Tuple[torch.Tensor, ...]:
        """Returns an input for the model as a tuple."""
        ...

    @abstractmethod
    def get_loss(self, input, output) -> torch.Tensor:
        """Returns the loss given the input and output."""
        ...

    @abstractmethod
    def run_backward(self, loss) -> None:
        """Runs the backward pass (e.g. including ``loss.backward()``)."""
        ...

    @staticmethod
    @abstractmethod
    def init(*args: Any, **kwargs: Any) -> nn.Module:
        """Initializes an instance of this model."""
        ...
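

# The following is a minimal sketch of the interface above for illustration
# only (it is not used by the test suite): a single-linear-layer model whose
# names and dimensions are arbitrary.
class _ExampleLinearModel(FSDPTestModel):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)

    def get_input(self, device) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(4, 8, device=device),)

    def get_loss(self, input, output) -> torch.Tensor:
        return output.sum()

    def run_backward(self, loss) -> None:
        loss.backward()

    @staticmethod
    def init(*args: Any, **kwargs: Any) -> nn.Module:
        return _ExampleLinearModel()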


def _assert_module_states(
    model: nn.Module,
    process_group: dist.ProcessGroup,
    assert_fn: Callable,
):
    """
    All-gathers module states across ranks and calls ``assert_fn`` on each pair
    of corresponding states from rank 0 and a nonzero rank. For example, if
    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
    states are equal across ranks.
    """
    # Include names for debugging convenience
    named_module_states = [
        (param_name, param.detach().cpu())
        for param_name, param in model.named_parameters()
    ]
    named_module_states += [
        (buffer_name, buffer.detach().cpu())
        for buffer_name, buffer in model.named_buffers()
    ]
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, named_module_states, group=process_group)
    rank0_states = olist[0]
    assert rank0_states is not None  # mypy
    for state in olist[1:]:
        assert state is not None  # mypy
        for (_, p1), (_, p2) in zip(rank0_states, state):
            assert_fn(p1, p2)


def _zero_model(
    model: nn.Module,
    zero_buffers: bool = False,
    summon_full=True,
):
    """Zeros the parameters and optionally buffers of ``model`` in place."""
    ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
    with ctx:
        for param in model.parameters():
            with torch.no_grad():
                param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                with torch.no_grad():
                    buffer.zero_()


def _get_state_dict(model, cpu_offload=False, half=False):
    if not cpu_offload:
        model = model.cuda()
    if half:
        model.half()

    return model.state_dict()


def subtest_name(test_name_mapping, *args):
    return "_".join(
        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
    )


def _broadcast_state_dict(rank, state_dict):
    # For non-FSDP roots, some parts of the model state on rank 0 may
    # not be on CPU, so we move everything to CPU to avoid issues like:
    # https://github.com/pytorch/pytorch/issues/77113.
    for param_name, param in state_dict.items():
        if param.device != torch.device("cpu"):
            state_dict[param_name] = param.cpu()

    olist = [state_dict if rank == 0 else None]
    dist.broadcast_object_list(olist)
    state_dict = olist[0]
    # Ensure that the state is on CUDA
    for param_name in state_dict.keys():
        state_dict[param_name] = state_dict[param_name].cuda()
    return state_dict


def get_full_params(model: nn.Module, recurse: bool = True):
    """
    Returns the full unsharded parameters of ``model``. Any FSDP-managed
    parameters offloaded to CPU are moved to GPU in the returned list.

    Args:
        recurse (bool): If ``False``, only unshards the parameters immediate to
            ``model``; if ``True``, recurses through the module hierarchy
            rooted at ``model``.
    """
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))
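

# The following is an illustrative sketch (not used by the test suite) of a
# common pattern built on the helpers above: comparing a reference model's
# parameters against an FSDP model's unsharded parameters, assuming
# `test_case` provides `assertEqual` (e.g. a subclass of `FSDPTest` below).
def _example_check_full_params(
    test_case, ref_model: nn.Module, fsdp_model: nn.Module
) -> None:
    fsdp_full_params = get_full_params(fsdp_model)
    test_case.assertEqual(list(ref_model.parameters()), fsdp_full_params)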


def _maybe_cuda(model: nn.Module, move_to_cuda: bool):
    return model.cuda() if move_to_cuda else model


def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
    return model if not wrap_fsdp else FSDP(model, *args, **kwargs)


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size

    def allreduce(self, *args, **kwargs):
        dist_wait = mock.Mock()

        def get_future():
            future: torch.futures.Future = torch.futures.Future()
            future.set_result(1)
            return future

        dist_wait.get_future = get_future
        return dist_wait


class TransformerWithSharedParams(FSDPTestModel):
    def __init__(
        self,
        group: dist.ProcessGroup,
        cuda_init_mode: CUDAInitMode,
        add_bn: bool,
        deterministic: bool,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        if deterministic:
            torch.manual_seed(0)
        d_vocab = 23
        d_model = 16

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer(
            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
        )
        self.register_buffer(
            "long_buffer",
            torch.zeros_like(self.vocab_bias, dtype=torch.long),
        )  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
            self = self.cuda()
        if deterministic:
            self.eval()

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        src = self.embed_tokens(src_ids)
        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        _, tgt = input
        return nn.functional.cross_entropy(
            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
        )

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        add_bn: bool = True,
    ) -> Union[nn.Module, FSDP]:
        """
        Initializes a :class:`TransformerWithSharedParams` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps with
                top-level FSDP. By default, the top-level FSDP uses the
                ``ModuleWrapPolicy`` for encoder and decoder layers, but a
                different auto wrap policy may be specified via
                ``fsdp_kwargs``.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            add_bn (bool): Whether to include batch norm in the model.
        """

        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            if isinstance(group, tuple):
                pg = group[0]
            else:
                pg = group
            return TransformerWithSharedParams(
                pg, cuda_init_mode, add_bn, deterministic
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Default to the `ModuleWrapPolicy`
            if "auto_wrap_policy" not in fsdp_kwargs:
                auto_wrap_policy = ModuleWrapPolicy(
                    {
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    }
                )
            else:
                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")

            if (
                "sharding_strategy" in fsdp_kwargs
                and fsdp_kwargs["sharding_strategy"]
                in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2}
                and not isinstance(group, tuple)
            ):
                fsdp_pg = None
            else:
                fsdp_pg = group

            if isinstance(group, tuple):
                tformer_pg = group[0]
            else:
                tformer_pg = group

            m = TransformerWithSharedParams(
                tformer_pg, cuda_init_mode, add_bn, deterministic
            )
            fsdp_model = FSDP(
                m,
                fsdp_pg,
                auto_wrap_policy=auto_wrap_policy,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")

    def get_ignored_modules(self):
        return [self.transformer]
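

# The following is an illustrative sketch (not used by the test suite) of how
# tests typically construct this model: recursive FSDP wrapping with a
# deterministic initialization; the `use_orig_params` kwarg is only an example
# of an FSDP constructor argument that can be forwarded.
def _example_init_transformer(process_group: dist.ProcessGroup) -> nn.Module:
    return TransformerWithSharedParams.init(
        process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs={"use_orig_params": True},
        deterministic=True,
    )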


class NestedWrappedModule(FSDPTestModel):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
            _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
        )

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = output.sum()
        return loss

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ) -> nn.Module:
        """
        Initializes a :class:`NestedWrappedModule` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
                modules with FSDP but not the top-level module. The model may
                later be wrapped with a top-level FSDP external to this method
                if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return NestedWrappedModule(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Does not wrap with top-level FSDP
            fsdp_model = NestedWrappedModule(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")
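

# The following is an illustrative sketch (not used by the test suite): since
# `RECURSIVE` init only wraps the nested submodules, a test that wants a fully
# data-parallel model typically adds the top-level FSDP wrapping itself; the
# `sharding_strategy` kwarg is only an example of a forwarded argument.
def _example_init_nested(process_group: dist.ProcessGroup) -> FSDP:
    model = NestedWrappedModule.init(
        process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs={"sharding_strategy": ShardingStrategy.FULL_SHARD},
        deterministic=True,
    )
    return FSDP(model, process_group)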


class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ):
        """
        Initializes a :class:`NestedWrappedModule` instance, but unlike
        :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
        policy.
        """
        model = super(
            AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
        ).init(
            group=group,
            fsdp_init_mode=FSDPInitMode.NO_FSDP,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
        )
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            fsdp_kwargs = fsdp_kwargs or {}
            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model


class NonUniformReqGradNWM(NestedWrappedModule):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super(NestedWrappedModule, self).__init__()
        # This `__init__` only differs from `NestedWrappedModule.__init__` in
        # that the last two `nn.Linear` layers are wrapped with FSDP in an
        # `nn.Sequential` container. This arrangement results in all elements
        # of the last two parameters residing on a single rank. Freezing all
        # parameters except those two allows us to verify that
        # `ShardedGradScaler` accommodates situations where some ranks have no
        # (non-zero sized) parameter shards.
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_cuda(nn.Linear(16, 4), move_to_cuda),
                    _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
                ),
            ),
        )

    @staticmethod
    def _set_nonuniform_req_grad(model, req_grad_mask) -> None:
        for n, p in model.named_parameters():
            if not re.match(req_grad_mask, n):
                p.requires_grad_(False)

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ):
        """
        Initializes a :class:`NestedWrappedModule` instance, but unlike
        :meth:`NestedWrappedModule.init`, it wraps a second
        :class:`torch.nn.Sequential` container to enable the desired
        non-uniform ``requires_grad`` tests with ``use_orig_params=True``. For
        both the ``RECURSIVE`` and ``NO_FSDP`` init modes, this freezes all
        parameters except the last two to validate ``ShardedGradScaler``
        support for ranks with no (non-zero sized) local shards in FSDP
        ``use_orig_params=True`` mode.
        """
        # The parameters that should remain unfrozen are in `module.2.1`. The
        # regex pattern below matches the relevant parameter names both with
        # and without an interstitial FSDP module indicator
        # (`_fsdp_wrapped_module`) present.
        req_grad_pattern = re.compile(r"module\.2.*\.1.*")
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            ddp_model = NonUniformReqGradNWM(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
            )
            NonUniformReqGradNWM._set_nonuniform_req_grad(ddp_model, req_grad_pattern)
            return ddp_model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            if fsdp_kwargs is None:
                fsdp_kwargs = {}
            fsdp_model = NonUniformReqGradNWM(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            NonUniformReqGradNWM._set_nonuniform_req_grad(fsdp_model, req_grad_pattern)
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")


class ModuleWithDelay(FSDPTestModel):
    """This class wraps a :class:`FSDPTestModel` to optionally add a delay
    after computing the loss and/or before the gradient reduction."""

    def __init__(
        self,
        module: nn.Module,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
    ):
        super().__init__()
        self.delay_after_loss_ms = delay_after_loss_ms
        self.delay_before_reduction_ms = delay_before_reduction_ms
        self.module = module

    def get_input(self, device):
        return self.module.get_input(device)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = self.module.get_loss(input, output)
        if self.delay_after_loss_ms > 0:
            torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
        return loss

    def run_backward(self, loss):
        orig_reduce_scatter = torch.distributed.reduce_scatter_tensor

        def _delayed_reduce_scatter(*args, **kwargs):
            if self.delay_before_reduction_ms > 0:
                torch.cuda._sleep(
                    int(self.delay_before_reduction_ms * get_cycles_per_ms())
                )
            return orig_reduce_scatter(*args, **kwargs)

        with mock.patch(
            "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter
        ):
            self.module.run_backward(loss)

    @staticmethod
    def init(
        module_class: Type[FSDPTestModel],
        *model_args: Any,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
        **model_kwargs: Any,
    ):
        """
        Args:
            module_class (Type[FSDPTestModel]): Wrapped module class to which
                to add delays.
            model_args: Positional arguments forwarded to the ``module_class``
                ``init()``.
            delay_after_loss_ms (int): Delay after computing the loss/before
                the optimizer step (in ms).
            delay_before_reduction_ms (int): Delay before reduce-scattering
                gradients (in ms).
            model_kwargs: Keyword arguments forwarded to the ``module_class``
                ``init()``.
        """
        return ModuleWithDelay(
            module_class.init(*model_args, **model_kwargs),
            delay_after_loss_ms,
            delay_before_reduction_ms,
        )


class NestedWrappedModuleWithDelay(ModuleWithDelay):
    @staticmethod
    def init(  # type: ignore[override]
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_after_loss_ms: int = 0,
        delay_before_reduction_ms: int = 0,
    ):
        return ModuleWithDelay.init(
            NestedWrappedModule,
            group=group,
            fsdp_init_mode=fsdp_init_mode,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
            delay_after_loss_ms=delay_after_loss_ms,
            delay_before_reduction_ms=delay_before_reduction_ms,
        )
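

# The following is an illustrative sketch (not used by the test suite): build a
# nested-wrapped model that sleeps for 250 ms after computing the loss and for
# 250 ms before each gradient reduce-scatter, which is useful for exercising
# communication/computation overlap; the delay values are arbitrary.
def _example_init_delayed_nested(process_group: dist.ProcessGroup) -> nn.Module:
    return NestedWrappedModuleWithDelay.init(
        process_group,
        FSDPInitMode.RECURSIVE,
        cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
        delay_after_loss_ms=250,
        delay_before_reduction_ms=250,
    )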


class DummyDDP(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class MixtureOfExperts(NestedWrappedModule):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        delay_before_free_ms: int,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__(
            group=group,
            wrap_fsdp=wrap_fsdp,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
        )
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
        if deterministic:
            # Give each rank different expert parameters
            torch.manual_seed(42 + self.rank)
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum(p.numel() for p in expert.parameters())
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        if deterministic:
            # Keep all other parameters the same across ranks
            torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()]
            )  # world size 1 means no shard
            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda),
        )

    def forward(self, x):
        if self.delay_before_free_ms > 0:
            expert = self.module[2]
            if isinstance(expert, FSDP):
                orig_reshard = torch.distributed.fsdp._runtime_utils._reshard

                def _delayed_reshard(*args, **kwargs):
                    torch.cuda._sleep(
                        int(self.delay_before_free_ms * get_cycles_per_ms())
                    )
                    return orig_reshard(*args, **kwargs)

                # This patch covers any `import torch..._reshard` uses.
                with mock.patch(
                    "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard
                ):
                    return self.module(x)

        return self.module(x)

    def run_backward(self, loss):
        loss.backward()
        # Manually reduce gradients if not wrapped in FullyShardedDataParallel
        if not self.wrap_fsdp:
            with torch.no_grad():
                for p in self.parameters():
                    if hasattr(p, "expert"):
                        continue  # these params don't need grad reduction
                    if p.grad is not None:
                        p.grad.div_(self.world_size)
                        torch.distributed.all_reduce(p.grad, group=self.group)

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_before_free_ms: int = 0,
    ):
        """
        Initializes a :class:`MixtureOfExperts` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
                modules with FSDP, including the expert and shared layers, but
                not the top-level module. The model may later be wrapped with a
                top-level FSDP external to this method if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            delay_before_free_ms (int): Delay before resharding expert
                parameters in the forward pass (in ms).
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return MixtureOfExperts(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Does not wrap with top-level FSDP
            fsdp_model = MixtureOfExperts(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")


class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        device: Optional[torch.device] = None,
        *,
        bias: bool = True,
        with_buffer: bool = False,
        dim_multiplier: int = 4,
    ):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim_multiplier * dim, device=device, bias=bias)
        self.out_proj = nn.Linear(dim_multiplier * dim, dim, device=device, bias=bias)
        if with_buffer:
            self.register_buffer("buffer", torch.randn((dim,), device=device))
        else:
            self.buffer = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.in_proj(x)
        z = F.relu(z)
        z = self.out_proj(z)
        z = F.relu(z)
        if self.buffer is not None:
            z = z + self.buffer
        return z

    def reset_parameters(self):
        if self.buffer is not None:
            torch.nn.init.normal_(self.buffer)
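

# The following is an illustrative sketch (not used by the test suite): apply
# the composable `fully_shard` API to an `MLP`, sharding each projection and
# then the root module over `mesh`, assuming a 1D device mesh and an
# already-initialized process group.
def _example_fully_shard_mlp(mesh: DeviceMesh) -> nn.Module:
    model = MLP(dim=16, device=torch.device("cuda"))
    fully_shard(model.in_proj, mesh=mesh)
    fully_shard(model.out_proj, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model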


class MLPStack(nn.Sequential):
    def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
        modules: List[nn.Module] = [
            # Use multiplier of 3 to exercise uneven case
            MLP(mlp_dim, dim_multiplier=3),
            MLP(mlp_dim),
            MLP(mlp_dim, dim_multiplier=3),
        ]
        if with_seq_parallel:
            modules.append(nn.LayerNorm(mlp_dim, bias=False))
        super().__init__(*modules)
        self.with_seq_parallel = with_seq_parallel

    def parallelize(
        self,
        tp_mesh: DeviceMesh,
        dp_mesh: DeviceMesh,
        use_activation_checkpointing: bool,
        **fsdp_kwargs,
    ) -> "MLPStack":
        parallelize_plan = {
            # Pass `use_local_output=False` to keep as DTensor to preserve
            # uneven activation dims
            "0.in_proj": ColwiseParallel(use_local_output=False),
            "0.out_proj": RowwiseParallel(use_local_output=False),
            "1.in_proj": ColwiseParallel(use_local_output=False),
            "1.out_proj": RowwiseParallel(use_local_output=False),
            "2.in_proj": ColwiseParallel(use_local_output=False),
            "2.out_proj": RowwiseParallel(output_layouts=Shard(1))
            if self.with_seq_parallel
            else RowwiseParallel(),
        }
        if self.with_seq_parallel:
            parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
        parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
        for module in self:
            if isinstance(module, nn.LayerNorm):
                continue
            if use_activation_checkpointing:
                checkpoint(module)
            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
        return self
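

# The following is an illustrative sketch (not used by the test suite): compose
# tensor parallelism and FSDP on an `MLPStack`, assuming `mesh_2d` is a 2D
# device mesh with mesh dimension names "dp" and "tp" (e.g. as constructed by
# `torch.distributed.device_mesh.init_device_mesh`).
def _example_parallelize_mlp_stack(mesh_2d: DeviceMesh) -> "MLPStack":
    model = MLPStack(mlp_dim=16)
    return model.parallelize(
        tp_mesh=mesh_2d["tp"],
        dp_mesh=mesh_2d["dp"],
        use_activation_checkpointing=False,
    )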


class DoubleLinear(nn.Module):
    """
    This can be used for returning multiple outputs from a module
    (``use_second_linear=True``) or for having an unused module (``False``).
    """

    def __init__(self, dim: int, use_second_linear: bool = True):
        super().__init__()
        self.lin1 = nn.Linear(dim, dim)
        self.lin2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.use_second_linear = use_second_linear

    def forward(
        self, x: torch.Tensor
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        if self.use_second_linear:
            return self.relu(self.lin1(x)), self.relu(self.lin2(x))
        return self.relu(self.lin1(x))


# NOTE: For these patch methods, if we want safety under multi-threading (e.g.
# when using a multi-threaded process group), then we want:
# (1) a barrier immediately after reading the original value to ensure that all
#     threads see the same original value
# (2) a barrier immediately before restoring the original value to ensure that
#     all threads use the patched value inside the context
@contextlib.contextmanager
def patch_all_gather(new_all_gather_into_tensor: Callable):
    orig_all_gather = dist.all_gather_into_tensor
    dist.barrier()
    dist.all_gather_into_tensor = new_all_gather_into_tensor
    try:
        yield
    finally:
        dist.barrier()
        dist.all_gather_into_tensor = orig_all_gather


@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
    orig_reduce_scatter = dist.reduce_scatter_tensor
    dist.barrier()
    dist.reduce_scatter_tensor = new_reduce_scatter_tensor
    try:
        yield
    finally:
        dist.barrier()
        dist.reduce_scatter_tensor = orig_reduce_scatter


@contextlib.contextmanager
def patch_all_reduce(new_all_reduce: Callable):
    orig_all_reduce = dist.all_reduce
    dist.barrier()
    dist.all_reduce = new_all_reduce
    try:
        yield
    finally:
        dist.barrier()
        dist.all_reduce = orig_all_reduce


@no_type_check
@contextlib.contextmanager
def patch_unshard(new_unshard: Callable):
    orig_unshard = FSDPParamGroup.unshard
    dist.barrier()
    FSDPParamGroup.unshard = new_unshard
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.unshard = orig_unshard


@no_type_check
@contextlib.contextmanager
def patch_reshard(new_reshard: Callable):
    orig_reshard = FSDPParamGroup.reshard
    dist.barrier()
    FSDPParamGroup.reshard = new_reshard
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.reshard = orig_reshard


@no_type_check
@contextlib.contextmanager
def patch_post_backward(new_post_backward: Callable):
    orig_post_backward = FSDPParamGroup.post_backward
    dist.barrier()
    FSDPParamGroup.post_backward = new_post_backward
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.post_backward = orig_post_backward


@no_type_check
@contextlib.contextmanager
def patch_register_post_backward_hook_backward(new_backward: Callable):
    orig_backward = RegisterPostBackwardFunction.backward
    dist.barrier()
    RegisterPostBackwardFunction.backward = new_backward
    try:
        yield
    finally:
        dist.barrier()
        RegisterPostBackwardFunction.backward = orig_backward


def reduce_scatter_with_assert(
    cls,
    orig_reduce_scatter: Callable,
    assert_fn: Callable,  # `assert_fn(output: Tensor)`
    *args: Any,
    **kwargs: Any,
):
    if len(args) > 0:
        output = args[0]
    elif "output" in kwargs:
        output = kwargs["output"]
    else:
        raise AssertionError(
            f"Cannot get reduce-scatter output from\nargs: {args}\nkwargs: {kwargs}"
        )
    assert_fn(output)
    return orig_reduce_scatter(*args, **kwargs)
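

# The following is an illustrative sketch (not used by the test suite) of how
# `patch_reduce_scatter` and `reduce_scatter_with_assert` are typically
# combined: run one forward/backward while asserting on every reduce-scatter
# output, assuming `test_case` provides `assertEqual` and that gradients are
# reduce-scattered in fp32.
def _example_assert_reduce_scatter_dtype(
    test_case, model: nn.Module, inp: torch.Tensor
) -> None:
    from functools import partial

    orig_reduce_scatter = dist.reduce_scatter_tensor

    def assert_fn(output: torch.Tensor) -> None:
        test_case.assertEqual(output.dtype, torch.float32)

    patched_reduce_scatter = partial(
        reduce_scatter_with_assert, test_case, orig_reduce_scatter, assert_fn
    )
    with patch_reduce_scatter(patched_reduce_scatter):
        model(inp).sum().backward()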


def check_sharded_parity(
    cls,  # unit test class
    replicated_module: nn.Module,
    sharded_module: nn.Module,
    prefixes_to_ignore: Tuple[str, ...] = (),
):
    for (replicated_name, replicated_param), (sharded_name, sharded_param) in zip(
        replicated_module.named_parameters(), sharded_module.named_parameters()
    ):
        clean_sharded_name = sharded_name
        for prefix in prefixes_to_ignore:
            clean_sharded_name = clean_sharded_name.replace(prefix, "")
        cls.assertEqual(replicated_name, clean_sharded_name)
        cls.assertIsInstance(sharded_param, DTensor)
        assert isinstance(sharded_param, DTensor)  # mypy
        mesh, placements = sharded_param.device_mesh, sharded_param.placements
        if tuple(placements) == (Shard(0), Shard(0)):
            raise AssertionError(
                "FSDP's (Shard(0), Shard(0)) layout differs from distribute_tensor(), "
                "so we cannot check for equality using it"
            )
        sharded_ref_param = distribute_tensor(replicated_param, mesh, placements)
        cls.assertEqual(sharded_param.to_local(), sharded_ref_param.to_local())
        if replicated_param.grad is None:
            cls.assertIsNone(sharded_param.grad)
            continue
        cls.assertIsNotNone(sharded_param.grad)
        sharded_ref_grad = distribute_tensor(replicated_param.grad, mesh, placements)
        cls.assertIsInstance(sharded_param.grad, DTensor)
        assert isinstance(sharded_param.grad, DTensor)  # mypy
        cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())


class FSDPTestMultiThread(MultiThreadedTestCase):
    @property
    def world_size(self):
        return torch.cuda.device_count() if torch.cuda.is_available() else 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)

    def perThreadSetUp(self):
        torch._dynamo.reset()

    def perThreadTearDown(self):
        torch._dynamo.reset()
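

# A minimal sketch of a multi-threaded FSDP unit test built on
# `FSDPTestMultiThread`, shown as a comment since it is illustrative only and
# not part of the test suite; the class, test, and dimension names are
# arbitrary:
#
#   class TestFullyShardExample(FSDPTestMultiThread):
#       @property
#       def world_size(self) -> int:
#           return 2
#
#       def test_fully_shard_mlp(self):
#           model = MLP(dim=8)
#           fully_shard(model)
#           self.assertIsInstance(model.in_proj.weight, DTensor)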


class FSDPTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`,
        # which can cause unit test flakiness:
        # https://github.com/pytorch/pytorch/issues/90848
        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
        self._spawn_processes()

    @property
    def world_size(self):
        return min(torch.cuda.device_count(), 8) if torch.cuda.is_available() else 4

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @property
    def init_method(self):
        return f"{FILE_SCHEMA}{self.file_name}"

    def _check_cpu_offload(self, fsdp_model, cpu_offload):
        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)

    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)

    def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
        self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch)

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        fake_pg = kwargs.get("fake_pg", False)

        print(f"dist init r={self.rank}, world={self.world_size}")

        # Specify the gloo backend to make `init_process_group()` succeed;
        # actual tests will be skipped if there are not enough GPUs.
        backend = "nccl" if torch.cuda.is_available() else "gloo"

        try:
            if fake_pg:
                store = torch.testing._internal.distributed.fake_pg.FakeStore()
                dist.init_process_group(
                    backend="fake",
                    world_size=self.world_size,
                    rank=rank,
                    store=store,
                )
            else:
                dist.init_process_group(
                    init_method=self.init_method,
                    backend=backend,
                    world_size=int(self.world_size),
                    rank=self.rank,
                )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        device_ids = None
        if torch.cuda.is_available() and torch.cuda.device_count():
            device_id = self.rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)
            device_ids = [device_id]

        # Execute a barrier prior to running the test to ensure that every
        # process has finished initialization and that the following test
        # immediately exiting due to a skip doesn't cause flakiness.
        dist.barrier(device_ids=device_ids)

        torch._dynamo.reset()
        self.run_test(test_name, pipe)
        torch._dynamo.reset()

        dist.barrier(device_ids=device_ids)

        dist.destroy_process_group()

    def _train_for_several_steps(
        self,
        model: nn.Module,
        num_steps: int,
        autocast: bool,
        lr: float = 0.01,
        fsdp_cpu_offload: Optional[CPUOffload] = None,
        save_model: bool = False,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
    ):
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params

        model_device = next(model.parameters()).device
        if sharded_grad_scaler_kwargs is None:
            sharded_grad_scaler_kwargs = {}
        sharded_grad_scaler = ShardedGradScaler(
            enabled=enable_sharded_grad_scaler, **sharded_grad_scaler_kwargs
        )
        # Use SGD with momentum instead of Adam since Adam is scale invariant,
        # which makes it a poor choice for these tests
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.amp.autocast("cuda", enabled=autocast):
                # Inputs are always on CUDA regardless of CPU offloading or
                # the model device
                input = model.module.get_input(torch.device("cuda"))
                if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)):
                    if isinstance(input, torch.Tensor):
                        input = input.half()
                    else:
                        input = tuple(x.half() for x in input)
                output = model(*input)
                # Post-forward, if CPU offloading, the model params should be on CPU.
                if (
                    cpu_offload_params
                    and isinstance(model, FSDP)
                    # If not resharding after forward, the parameters are still
                    # exposed as unsharded views into the GPU flat parameter
                    and model.sharding_strategy
                    not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
                ):
                    for p in model.parameters():
                        # Params should always be on CPU
                        self.assertEqual(p.device, torch.device("cpu"))

                loss = model.module.get_loss(input, output).to(model_device)
            loss = sharded_grad_scaler.scale(loss)

            if not mixed_precision and not use_pure_fp16:
                assert (
                    loss.dtype == torch.float32
                ), "loss data type should be float32, as the original \
                    parameter data type is float32."
            else:
                if use_pure_fp16:
                    self.assertEqual(loss.dtype, torch.float16)
                # FSDP loss is fp16, DDP AMP loss is fp32
                elif isinstance(model, FSDP):
                    assert mixed_precision is not None  # mypy
                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
                else:
                    self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            # Post-backward, if CPU offloading, the model params should be on CPU.
            if cpu_offload_params and isinstance(model, FSDP):
                for p in model.parameters():
                    # Params should always be on CPU
                    self.assertEqual(p.device, torch.device("cpu"))
            # Unscale the gradients and step
            sharded_grad_scaler.step(optim)
            # Update the scale factor
            sharded_grad_scaler.update()
            # If `save_model`, simulate save + load.
            if save_model:
                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
                # Zero the params; if the state dict save/load did not work
                # properly, this would break the parity check with DDP.
                _zero_model(model)
                model.load_state_dict(state_dict)

        if isinstance(model, FSDP):
            model._assert_state(TrainingState.IDLE)
        return loss.detach()  # type: ignore[possibly-undefined]

    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        forward_prefetch: bool = False,
        use_orig_params: bool = False,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        init_kwargs: Optional[Dict[str, Any]] = None,
        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
        """
        assert (
            fsdp_init_mode != FSDPInitMode.NO_FSDP
        ), "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
            sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update(
            {
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "sharding_strategy": sharding_strategy,
                "mixed_precision": mixed_precision,
                "forward_prefetch": forward_prefetch,
                "use_orig_params": use_orig_params,
            }
        )
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since we are comparing
            # assuming a data parallel reference and some test models may not
            # do so in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = (
            offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        )
        expects_cpu_device = (
            offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        )
        if expects_cpu_device:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (
            self.assertRaisesRegex(
                RuntimeError,
                "An FSDP-managed module with parameter CPU offloading enabled "
                "has parameters on cuda",
            )
            if expects_device_error
            else nullcontext()
        )
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
                sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check that parameter devices are CPU if offloading to CPU before
        # calling `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        # Do not check dtype since the reference DDP loss may not be the same
        # dtype as the FSDP loss in the case of mixed precision
        torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32
        # TODO: Disable checking the parameters for pure FP16 due to floating
        # point inaccuracy. Note that this means that the backward pass is not
        # checked: https://github.com/pytorch/pytorch/issues/90784
        if mixed_precision is None and not use_pure_fp16:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )
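

# A minimal sketch of a multi-process parity test built on `FSDPTest`, shown
# as a comment since it is illustrative only and not part of the test suite;
# it assumes enough GPUs are available at runtime:
#
#   class TestMyFSDPParity(FSDPTest):
#       def test_nested_wrapped_parity(self):
#           self._test_fsdp_parity(
#               NestedWrappedModule,
#               FSDPInitMode.RECURSIVE,
#               CUDAInitMode.CUDA_BEFORE,
#               cpu_offload=CPUOffload(offload_params=False),
#               use_orig_params=True,
#           )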


def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None):
    def fully_shard_with_compiled_compute(*args, **kwargs):
        torch.distributed._composable.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
        if compile_compute_on_module is None or isinstance(
            args[0], compile_compute_on_module
        ):
            args[0].compile()

    class FullyShardMode(Enum):
        EAGER = auto()
        COMPILED_COMPUTE = auto()

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            original_fully_shard = torch.distributed._composable.fsdp.fully_shard
            for mode in FullyShardMode:
                if mode != FullyShardMode.EAGER and not has_triton():
                    warnings.warn("Inductor on GPU needs Triton and recent GPU arch")
                    continue
                # Barrier to ensure that all threads read the same original values
                original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
                original_compile_threads = torch._inductor.config.compile_threads
                torch.distributed.barrier()

                if mode == FullyShardMode.EAGER:
                    fully_shard_patch = original_fully_shard
                elif mode == FullyShardMode.COMPILED_COMPUTE:
                    torch._dynamo.config.skip_fsdp_hooks = True
                    torch._inductor.config.compile_threads = 1
                    fully_shard_patch = fully_shard_with_compiled_compute  # type: ignore[assignment]
                else:
                    raise NotImplementedError(
                        f"Need to implement FullyShardMode={mode}"
                    )

                # `fully_shard` is imported as a global
                # through `from ... import fully_shard`
                func.__globals__[original_fully_shard.__name__] = fully_shard_patch
                func(*args, **kwargs)
                # Barrier so that other threads use the patched function before
                # this thread restores the original
                torch.distributed.barrier()
                func.__globals__[original_fully_shard.__name__] = original_fully_shard
                torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
                torch._inductor.config.compile_threads = original_compile_threads

        return wrapper

    return decorator
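

# A minimal sketch of how the decorator above is typically applied, shown as a
# comment since it is illustrative only and not part of the test suite; the
# wrapped test runs once eagerly and once with compiled compute on each `MLP`:
#
#   @test_compiled_fsdp(compile_compute_on_module=MLP)
#   def test_train_parity(self):
#       model = nn.Sequential(MLP(8), MLP(8))
#       for mlp in model:
#           fully_shard(mlp)
#       fully_shard(model)
#       ...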


class SkipModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.lin(x)


class NestedLinear(nn.Module):
    def __init__(self, fsdp_wrap):
        super().__init__()
        if fsdp_wrap:
            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
        else:
            self.nested_linear = nn.Linear(10, 10, bias=False).cuda()

    def forward(self, x):
        return self.nested_linear(x)


class SkipModel(nn.Module):
    def __init__(self, double_nest):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=False).cuda()
        self.linear_skip = SkipModule().cuda()
        self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))

    def forward(self, x):
        x = self.linear(x)
        x = self.linear_skip(x)
        x = self.nested_linear(x)
        return x