# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]

import contextlib
import os
import re
import sys
import warnings
from abc import ABC, abstractmethod
from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    List,
    no_type_check,
    Optional,
    Tuple,
    Type,
    Union,
)
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_param_group import (
    FSDPParamGroup,
    RegisterPostBackwardFunction,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import TrainingState
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap
from torch.distributed.tensor import distribute_tensor, DTensor, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
    SequenceParallel,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    run_subtests,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import FILE_SCHEMA, get_cycles_per_ms
from torch.utils._triton import has_triton


class FSDPInitMode(Enum):
    # No FSDP wrapping
    NO_FSDP = auto()
    # FSDP recursive wrapping
    RECURSIVE = auto()
    # TODO: FSDP non-recursive wrapping
    # NONRECURSIVE = auto()


class CUDAInitMode(Enum):
    # Move model to CUDA before passing to the FSDP constructor
    CUDA_BEFORE = auto()
    # Move model to CUDA after passing to the FSDP constructor
    CUDA_AFTER = auto()
    # Keep on CPU
    CUDA_NEVER = auto()


class FSDPTestModel(nn.Module, ABC):
86    """This defines the interface expected from all models used commonly for
87    FSDP unit tests."""
88
89    @abstractmethod
90    def get_input(self, device) -> Tuple[torch.Tensor, ...]:
91        """Returns an input for the model as as tuple."""
        ...

    @abstractmethod
    def get_loss(self, input, output) -> torch.Tensor:
        """Returns the loss given the input and output."""
        ...

    @abstractmethod
    def run_backward(self, loss) -> None:
        """Runs the backward pass (e.g. including ``loss.backward()``)."""
        ...

    @staticmethod
    @abstractmethod
    def init(*args: Any, **kwargs: Any) -> nn.Module:
        """Initializes an instance of this model."""
        ...


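# Illustrative sketch (not used by the test suite): a minimal model satisfying
# the `FSDPTestModel` interface above. The name `_ExampleLinearModel` is
# hypothetical and exists only as documentation of the expected interface.
class _ExampleLinearModel(FSDPTestModel):
    def __init__(self, deterministic: bool = False):
        super().__init__()
        if deterministic:
            torch.manual_seed(0)
        self.lin = nn.Linear(8, 8)

    def get_input(self, device) -> Tuple[torch.Tensor, ...]:
        # Return a tuple so that it can be unpacked as `model(*input)`
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.lin(x)

    def get_loss(self, input, output) -> torch.Tensor:
        return output.sum()

    def run_backward(self, loss) -> None:
        loss.backward()

    @staticmethod
    def init(*args: Any, **kwargs: Any) -> nn.Module:
        return _ExampleLinearModel(*args, **kwargs)

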
def _assert_module_states(
    model: nn.Module,
    process_group: dist.ProcessGroup,
    assert_fn: Callable,
):
    """
    All-gathers module states across ranks and calls ``assert_fn`` on each pair
    of corresponding states from rank 0 and a nonzero rank. For example, if
    ``assert_fn`` is ``self.assertEqual()``, then this checks that all module
    states are equal across ranks.
    """
    # Include names for debugging convenience
    named_module_states = [
        (param_name, param.detach().cpu())
        for param_name, param in model.named_parameters()
    ]
    named_module_states += [
        (buffer_name, buffer.detach().cpu())
        for buffer_name, buffer in model.named_buffers()
    ]
    world_size = dist.get_world_size(process_group)
    olist = [None for _ in range(world_size)]
    dist.all_gather_object(olist, named_module_states, group=process_group)
    rank0_states = olist[0]
    assert rank0_states is not None  # mypy
    for state in olist[1:]:
        assert state is not None  # mypy
        for (_, p1), (_, p2) in zip(rank0_states, state):
            assert_fn(p1, p2)


def _zero_model(
    model: nn.Module,
    zero_buffers: bool = False,
    summon_full=True,
):
    """Zeros the parameters and optionally buffers of ``model`` in place."""
    ctx = FSDP.summon_full_params(model) if summon_full else nullcontext()
    with ctx:
        for param in model.parameters():
            with torch.no_grad():
                param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                with torch.no_grad():
                    buffer.zero_()


def _get_state_dict(model, cpu_offload=False, half=False):
    if not cpu_offload:
        model = model.cuda()
    if half:
        model.half()

    return model.state_dict()


def subtest_name(test_name_mapping, *args):
    return "_".join(
        [test_name_mapping[str(s)] if s is not None else "none" for s in args]
    )


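# Illustrative sketch of how `subtest_name` composes parametrized test names
# from a mapping keyed by `str(config)`; the mapping below is hypothetical.
def _example_subtest_name() -> str:
    name_mapping = {str(CPUOffload(offload_params=True)): "offload_params"}
    # Produces "offload_params_none" (`None` maps to "none" without a lookup)
    return subtest_name(name_mapping, CPUOffload(offload_params=True), None)

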
def _broadcast_state_dict(rank, state_dict):
    # For non-FSDP roots, some parts of the model state on rank 0 may
    # not be on CPU, so we move everything to CPU to avoid issues like:
    # https://github.com/pytorch/pytorch/issues/77113.
    for param_name, param in state_dict.items():
        if param.device != torch.device("cpu"):
            state_dict[param_name] = param.cpu()

    olist = [state_dict if rank == 0 else None]
    dist.broadcast_object_list(olist)
    state_dict = olist[0]
    # Ensure that the state is on CUDA
    for param_name in state_dict.keys():
        state_dict[param_name] = state_dict[param_name].cuda()
    return state_dict


def get_full_params(model: nn.Module, recurse: bool = True):
    """
    Returns the full unsharded parameters of ``model``. Any FSDP-managed
    parameters offloaded to CPU are moved to GPU in the returned list.

    Args:
        recurse (bool): If ``False``, only unshards the parameters immediate to
            ``model``; if ``True``, recurses through the module hierarchy
            rooted at ``model``.
    """
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))


def _maybe_cuda(model: nn.Module, move_to_cuda: bool):
    return model.cuda() if move_to_cuda else model


def _maybe_wrap_fsdp(model: nn.Module, wrap_fsdp: bool, *args, **kwargs):
    return model if not wrap_fsdp else FSDP(model, *args, **kwargs)


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size

    def allreduce(self, *args, **kwargs):
        dist_wait = mock.Mock()

        def get_future():
            future: torch.futures.Future = torch.futures.Future()
            future.set_result(1)
            return future

        dist_wait.get_future = get_future
        return dist_wait


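# Illustrative sketch: `DummyProcessGroup` can stand in for a real process
# group in single-process unit tests that only query rank/size or need an
# already-completed `allreduce()` future. This helper is hypothetical.
def _example_dummy_process_group() -> None:
    pg = DummyProcessGroup(rank=0, size=1)
    assert pg.rank() == 0 and pg.size() == 1
    work = pg.allreduce(torch.ones(1))
    # The mocked work object exposes a future that is already completed
    assert work.get_future().wait() == 1

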
class TransformerWithSharedParams(FSDPTestModel):
    def __init__(
        self,
        group: dist.ProcessGroup,
        cuda_init_mode: CUDAInitMode,
        add_bn: bool,
        deterministic: bool,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        if deterministic:
            torch.manual_seed(0)
        d_vocab = 23
        d_model = 16

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer(
            "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
        )
        self.register_buffer(
            "long_buffer",
            torch.zeros_like(self.vocab_bias, dtype=torch.long),
        )  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
            self = self.cuda()
        if deterministic:
            self.eval()

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
        return (src, tgt)

    def forward(self, src_ids, tgt_ids):
        src = self.embed_tokens(src_ids)
        src = src + self.vocab_bias + self.long_buffer.type_as(src)  # type: ignore[operator]
        tgt = self.embed_tokens(tgt_ids)
        tgt = self.bn(tgt)
        x = self.transformer(src, tgt)
        return self.output_proj(x)

    def get_loss(self, input, output):
        _, tgt = input
        return nn.functional.cross_entropy(
            output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum"
        )

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        add_bn: bool = True,
    ) -> Union[nn.Module, FSDP]:
        """
        Initializes a :class:`TransformerWithSharedParams` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps with
                top-level FSDP. By default, the top-level FSDP uses the
                ``ModuleWrapPolicy`` for encoder and decoder layers, but a
                different auto wrap policy may be specified via
                ``fsdp_kwargs``.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            add_bn (bool): Whether to include batch norm in the model.
        """

        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            if isinstance(group, tuple):
                pg = group[0]
            else:
                pg = group
            return TransformerWithSharedParams(
                pg, cuda_init_mode, add_bn, deterministic
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Default to the `ModuleWrapPolicy`
            if "auto_wrap_policy" not in fsdp_kwargs:
                auto_wrap_policy = ModuleWrapPolicy(
                    {
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    }
                )
            else:
                auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy")

            if (
                "sharding_strategy" in fsdp_kwargs
                and fsdp_kwargs["sharding_strategy"]
                in {ShardingStrategy.HYBRID_SHARD, ShardingStrategy._HYBRID_SHARD_ZERO2}
                and not isinstance(group, tuple)
            ):
                fsdp_pg = None
            else:
                fsdp_pg = group

            if isinstance(group, tuple):
                tformer_pg = group[0]
            else:
                tformer_pg = group

            m = TransformerWithSharedParams(
                tformer_pg, cuda_init_mode, add_bn, deterministic
            )
            fsdp_model = FSDP(
                m,
                fsdp_pg,
                auto_wrap_policy=auto_wrap_policy,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")

    def get_ignored_modules(self):
        return [self.transformer]


class NestedWrappedModule(FSDPTestModel):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
            _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
        )

    def get_input(self, device):
        torch.manual_seed(1 + self.rank)  # keep everything deterministic
        return (torch.rand(4, 8, device=device),)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = output.sum()
        return loss

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ) -> nn.Module:
        """
        Initializes a :class:`NestedWrappedModule` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
                modules with FSDP but not the top-level module. The model may
                later be wrapped with a top-level FSDP external to this method
                if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return NestedWrappedModule(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Does not wrap with top-level FSDP
            fsdp_model = NestedWrappedModule(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")


class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ):
        """
        Initializes a :class:`NestedWrappedModule` instance, but unlike
        :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
        wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
        policy.
        """
        model = super(
            AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule
        ).init(
            group=group,
            fsdp_init_mode=FSDPInitMode.NO_FSDP,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
        )
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            fsdp_kwargs = fsdp_kwargs or {}
            fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model


class NonUniformReqGradNWM(NestedWrappedModule):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super(NestedWrappedModule, self).__init__()
        # This `__init__` only differs from `NestedWrappedModule.__init__` in that
        # the last two `nn.Linear` layers are FSDP wrapped in a `nn.Sequential`
        # container. This arrangement results in all elements of the last two parameters
        # residing on a single rank. Freezing all parameters except those two allows us
        # to verify that `ShardedGradScaler` accommodates situations where some ranks
        # have no (non-zero sized) parameter shards.
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_cuda(nn.Linear(16, 4), move_to_cuda),
                    _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
                ),
            ),
        )

    @staticmethod
    def _set_nonuniform_req_grad(model, req_grad_mask) -> None:
        for n, p in model.named_parameters():
            if not re.match(req_grad_mask, n):
                p.requires_grad_(False)

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
    ):
        """
        Initializes a :class:`NestedWrappedModule` instance, but unlike
        :meth:`NestedWrappedModule.init`, it wraps a second :class:`torch.nn.Sequential`
        container to enable the desired non-uniform ``requires_grad``
        ``use_orig_params=True`` tests. For both ``RECURSIVE`` and ``NO_FSDP``
        init modes, freezes all parameters except the last two to validate
        ``ShardedGradScaler`` support for ranks with no (non-zero sized) local shards in
        FSDP ``use_orig_params=True`` mode.
        """
        # The parameters that should remain unfrozen are in `module.2.1`. The regex
        # pattern below matches the relevant parameter names both with and without
        # an interstitial FSDP module indicator (`_fsdp_wrapped_module`) present.
        req_grad_pattern = re.compile(r"module\.2.*\.1.*")
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            ddp_model = NonUniformReqGradNWM(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
            )
            NonUniformReqGradNWM._set_nonuniform_req_grad(ddp_model, req_grad_pattern)
            return ddp_model
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            if fsdp_kwargs is None:
                fsdp_kwargs = {}
            fsdp_model = NonUniformReqGradNWM(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            NonUniformReqGradNWM._set_nonuniform_req_grad(fsdp_model, req_grad_pattern)
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")


class ModuleWithDelay(FSDPTestModel):
    """This class wraps a :class:`FSDPTestModel` to optionally add a delay
    after computing the loss and/or before the gradient reduction."""

    def __init__(
        self,
        module: nn.Module,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
    ):
        super().__init__()
        self.delay_after_loss_ms = delay_after_loss_ms
        self.delay_before_reduction_ms = delay_before_reduction_ms
        self.module = module

    def get_input(self, device):
        return self.module.get_input(device)

    def forward(self, x):
        return self.module(x)

    def get_loss(self, input, output):
        loss = self.module.get_loss(input, output)
        if self.delay_after_loss_ms > 0:
            torch.cuda._sleep(int(self.delay_after_loss_ms * get_cycles_per_ms()))
        return loss

    def run_backward(self, loss):
        orig_reduce_scatter = torch.distributed.reduce_scatter_tensor

        def _delayed_reduce_scatter(*args, **kwargs):
            if self.delay_before_reduction_ms > 0:
                torch.cuda._sleep(
                    int(self.delay_before_reduction_ms * get_cycles_per_ms())
                )
            return orig_reduce_scatter(*args, **kwargs)

        with mock.patch(
            "torch.distributed.reduce_scatter_tensor", _delayed_reduce_scatter
        ):
            self.module.run_backward(loss)

    @staticmethod
    def init(
        module_class: Type[FSDPTestModel],
        *model_args: Any,
        delay_after_loss_ms: int,
        delay_before_reduction_ms: int,
        **model_kwargs: Any,
    ):
        """
        Args:
            module_class (Type[FSDPTestModel]): Wrapped module class to which
                to add delays.
            model_args: Positional arguments forwarded to the ``module_class``
                ``init()``.
            delay_after_loss_ms (int): Delay after computing the loss/before
                the optimizer step (in ms).
            delay_before_reduction_ms (int): Delay before reduce-scattering
                gradients (in ms).
            model_kwargs: Keyword arguments forwarded to the ``module_class``
                ``init()``.
        """
        return ModuleWithDelay(
            module_class.init(*model_args, **model_kwargs),
            delay_after_loss_ms,
            delay_before_reduction_ms,
        )


class NestedWrappedModuleWithDelay(ModuleWithDelay):
    @staticmethod
    def init(  # type: ignore[override]
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode = CUDAInitMode.CUDA_AFTER,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_after_loss_ms: int = 0,
        delay_before_reduction_ms: int = 0,
    ):
        return ModuleWithDelay.init(
            NestedWrappedModule,
            group=group,
            fsdp_init_mode=fsdp_init_mode,
            cuda_init_mode=cuda_init_mode,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=deterministic,
            delay_after_loss_ms=delay_after_loss_ms,
            delay_before_reduction_ms=delay_before_reduction_ms,
        )


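# Illustrative sketch (assumes an initialized default process group and CUDA):
# construct a nested FSDP model whose gradient reduce-scatter is delayed by
# 250 ms, e.g. to exercise communication/computation overlap. The helper name
# and the chosen delay are hypothetical.
def _example_nested_module_with_delay() -> nn.Module:
    return NestedWrappedModuleWithDelay.init(
        dist.distributed_c10d._get_default_group(),
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
        delay_before_reduction_ms=250,
    )

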
class DummyDDP(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class MixtureOfExperts(NestedWrappedModule):
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        delay_before_free_ms: int,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__(
            group=group,
            wrap_fsdp=wrap_fsdp,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
        )
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
        if deterministic:
            # Give each rank different expert parameters
            torch.manual_seed(42 + self.rank)
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum(p.numel() for p in expert.parameters())
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        if deterministic:
            # Keep all other parameters the same across ranks
            torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()]
            )  # world size 1 means no shard
            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda),
        )

    def forward(self, x):
        if self.delay_before_free_ms > 0:
            expert = self.module[2]
            if isinstance(expert, FSDP):
                orig_reshard = torch.distributed.fsdp._runtime_utils._reshard

                def _delayed_reshard(*args, **kwargs):
                    torch.cuda._sleep(
                        int(self.delay_before_free_ms * get_cycles_per_ms())
                    )
                    return orig_reshard(*args, **kwargs)

                # This patch covers any `import torch..._reshard` uses.
                with mock.patch(
                    "torch.distributed.fsdp._runtime_utils._reshard", _delayed_reshard
                ):
                    return self.module(x)

        return self.module(x)

    def run_backward(self, loss):
        loss.backward()
        # Manually reduce gradients if not wrapped in FullyShardedDataParallel
        if not self.wrap_fsdp:
            with torch.no_grad():
                for p in self.parameters():
                    if hasattr(p, "expert"):
                        continue  # these params don't need grad reduction
                    if p.grad is not None:
                        p.grad.div_(self.world_size)
                        torch.distributed.all_reduce(p.grad, group=self.group)

    @staticmethod
    def init(
        group: dist.ProcessGroup,
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
        deterministic: bool = False,
        delay_before_free_ms: int = 0,
    ):
        """
        Initializes a :class:`MixtureOfExperts` instance.

        Args:
            fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap
                any modules with FSDP. If ``RECURSIVE``, then wraps some nested
                modules with FSDP, including the expert and shared layers, but
                not the top-level module. The model may later be wrapped with a
                top-level FSDP external to this method if desired.
            cuda_init_mode (CUDAInitMode): Determines model movement to CUDA.
            fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments
                forwarded to the FSDP constructor.
            deterministic (bool): Whether to make the model deterministic
                across constructions.
            delay_before_free_ms (int): Delay before resharding expert
                parameters in the forward pass (in ms).
        """
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        if fsdp_init_mode == FSDPInitMode.NO_FSDP:
            return MixtureOfExperts(
                group,
                wrap_fsdp=False,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
            )
        elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
            # Does not wrap with top-level FSDP
            fsdp_model = MixtureOfExperts(
                group,
                wrap_fsdp=True,
                cuda_init_mode=cuda_init_mode,
                delay_before_free_ms=delay_before_free_ms,
                deterministic=deterministic,
                **fsdp_kwargs,
            )
            if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
                fsdp_model = fsdp_model.cuda()
            return fsdp_model
        raise ValueError(f"Unsupported FSDP init mode: {fsdp_init_mode}")


class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        device: Optional[torch.device] = None,
        *,
        bias: bool = True,
        with_buffer: bool = False,
        dim_multiplier: int = 4,
    ):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim_multiplier * dim, device=device, bias=bias)
        self.out_proj = nn.Linear(dim_multiplier * dim, dim, device=device, bias=bias)
        if with_buffer:
            self.register_buffer("buffer", torch.randn((dim,), device=device))
        else:
            self.buffer = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.in_proj(x)
        z = F.relu(z)
        z = self.out_proj(z)
        z = F.relu(z)
        if self.buffer is not None:
            z = z + self.buffer
        return z

    def reset_parameters(self):
        if self.buffer is not None:
            torch.nn.init.normal_(self.buffer)


class MLPStack(nn.Sequential):
    def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
        modules: List[nn.Module] = [
            # Use multiplier of 3 to exercise uneven case
            MLP(mlp_dim, dim_multiplier=3),
            MLP(mlp_dim),
            MLP(mlp_dim, dim_multiplier=3),
        ]
        if with_seq_parallel:
            modules.append(nn.LayerNorm(mlp_dim, bias=False))
        super().__init__(*modules)
        self.with_seq_parallel = with_seq_parallel

    def parallelize(
        self,
        tp_mesh: DeviceMesh,
        dp_mesh: DeviceMesh,
        use_activation_checkpointing: bool,
        **fsdp_kwargs,
    ) -> "MLPStack":
        parallelize_plan = {
            # Pass `use_local_output=False` to keep as DTensor to preserve
            # uneven activation dims
            "0.in_proj": ColwiseParallel(use_local_output=False),
            "0.out_proj": RowwiseParallel(use_local_output=False),
            "1.in_proj": ColwiseParallel(use_local_output=False),
            "1.out_proj": RowwiseParallel(use_local_output=False),
            "2.in_proj": ColwiseParallel(use_local_output=False),
            "2.out_proj": RowwiseParallel(output_layouts=Shard(1))
            if self.with_seq_parallel
            else RowwiseParallel(),
        }
        if self.with_seq_parallel:
            parallelize_plan["3"] = SequenceParallel(sequence_dim=1)
        parallelize_module(self, device_mesh=tp_mesh, parallelize_plan=parallelize_plan)
        for module in self:
            if isinstance(module, nn.LayerNorm):
                continue
            if use_activation_checkpointing:
                checkpoint(module)
            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
        return self


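# Illustrative sketch of composing TP + FSDP on an `MLPStack`, assuming a 2D
# device mesh with ("dp", "tp") dim names has been initialized elsewhere, e.g.
# via `init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))`.
def _example_parallelize_mlp_stack(mesh_2d: DeviceMesh) -> MLPStack:
    model = MLPStack(mlp_dim=16)
    # Tensor parallelism uses the "tp" sub-mesh; FSDP shards over the "dp" one
    return model.parallelize(
        tp_mesh=mesh_2d["tp"],
        dp_mesh=mesh_2d["dp"],
        use_activation_checkpointing=False,
    )

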
class DoubleLinear(nn.Module):
    """
    This can be used for returning multiple outputs from a module
    (``use_second_linear=True``) or for having an unused module (``False``).
    """

    def __init__(self, dim: int, use_second_linear: bool = True):
        super().__init__()
        self.lin1 = nn.Linear(dim, dim)
        self.lin2 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.use_second_linear = use_second_linear

    def forward(
        self, x: torch.Tensor
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        if self.use_second_linear:
            return self.relu(self.lin1(x)), self.relu(self.lin2(x))
        return self.relu(self.lin1(x))


# NOTE: For these patch methods, if we want safety under multi-threading (e.g.
# when using multi-threaded process group), then we want:
# (1) a barrier immediately after reading the original value to ensure that all
# threads see the same original value
# (2) a barrier immediately before restoring the original value to ensure that
# all threads use the patched value inside the context
@contextlib.contextmanager
def patch_all_gather(new_all_gather_into_tensor: Callable):
    orig_all_gather = dist.all_gather_into_tensor
    dist.barrier()
    dist.all_gather_into_tensor = new_all_gather_into_tensor
    try:
        yield
    finally:
        dist.barrier()
        dist.all_gather_into_tensor = orig_all_gather


@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
    orig_reduce_scatter = dist.reduce_scatter_tensor
    dist.barrier()
    dist.reduce_scatter_tensor = new_reduce_scatter_tensor
    try:
        yield
    finally:
        dist.barrier()
        dist.reduce_scatter_tensor = orig_reduce_scatter


@contextlib.contextmanager
def patch_all_reduce(new_all_reduce: Callable):
    orig_all_reduce = dist.all_reduce
    dist.barrier()
    dist.all_reduce = new_all_reduce
    try:
        yield
    finally:
        dist.barrier()
        dist.all_reduce = orig_all_reduce


@no_type_check
@contextlib.contextmanager
def patch_unshard(new_unshard: Callable):
    orig_unshard = FSDPParamGroup.unshard
    dist.barrier()
    FSDPParamGroup.unshard = new_unshard
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.unshard = orig_unshard


@no_type_check
@contextlib.contextmanager
def patch_reshard(new_reshard: Callable):
    orig_reshard = FSDPParamGroup.reshard
    dist.barrier()
    FSDPParamGroup.reshard = new_reshard
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.reshard = orig_reshard


@no_type_check
@contextlib.contextmanager
def patch_post_backward(new_post_backward: Callable):
    orig_post_backward = FSDPParamGroup.post_backward
    dist.barrier()
    FSDPParamGroup.post_backward = new_post_backward
    try:
        yield
    finally:
        dist.barrier()
        FSDPParamGroup.post_backward = orig_post_backward


@no_type_check
@contextlib.contextmanager
def patch_register_post_backward_hook_backward(new_backward: Callable):
    orig_backward = RegisterPostBackwardFunction.backward
    dist.barrier()
    RegisterPostBackwardFunction.backward = new_backward
    try:
        yield
    finally:
        dist.barrier()
        RegisterPostBackwardFunction.backward = orig_backward


def reduce_scatter_with_assert(
    cls,
    orig_reduce_scatter: Callable,
    assert_fn: Callable,  # `assert_fn(output: Tensor)`
    *args: Any,
    **kwargs: Any,
):
    if len(args) > 0:
        output = args[0]
    elif "output" in kwargs:
        output = kwargs["output"]
    else:
        raise AssertionError(
            f"Cannot get reduce-scatter output from\nargs: {args}\nkwargs: {kwargs}"
        )
    assert_fn(output)
    return orig_reduce_scatter(*args, **kwargs)


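# Illustrative sketch (assumes an initialized default process group) of how the
# patch context managers above combine with `reduce_scatter_with_assert`: the
# original collective is wrapped so that every reduce-scatter output tensor is
# checked during backward. `cls` is the unit test instance and the dtype
# assertion is hypothetical.
def _example_assert_on_reduce_scatter(
    cls, model: nn.Module, inp: torch.Tensor
) -> None:
    from functools import partial

    orig_reduce_scatter = dist.reduce_scatter_tensor

    def assert_fn(output: torch.Tensor) -> None:
        cls.assertEqual(output.dtype, torch.float32)

    patched_reduce_scatter = partial(
        reduce_scatter_with_assert, cls, orig_reduce_scatter, assert_fn
    )
    with patch_reduce_scatter(patched_reduce_scatter):
        model(inp).sum().backward()

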
def check_sharded_parity(
    cls,  # unit test class
    replicated_module: nn.Module,
    sharded_module: nn.Module,
    prefixes_to_ignore: Tuple[str, ...] = (),
):
    for (replicated_name, replicated_param), (sharded_name, sharded_param) in zip(
        replicated_module.named_parameters(), sharded_module.named_parameters()
    ):
        clean_sharded_name = sharded_name
        for prefix in prefixes_to_ignore:
            clean_sharded_name = clean_sharded_name.replace(prefix, "")
        cls.assertEqual(replicated_name, clean_sharded_name)
        cls.assertIsInstance(sharded_param, DTensor)
        assert isinstance(sharded_param, DTensor)  # mypy
        mesh, placements = sharded_param.device_mesh, sharded_param.placements
        if tuple(placements) == (Shard(0), Shard(0)):
            raise AssertionError(
                "FSDP's (Shard(0), Shard(0)) layout differs from distribute_tensor(), "
                "so we cannot check for equality using it"
            )
        sharded_ref_param = distribute_tensor(replicated_param, mesh, placements)
        cls.assertEqual(sharded_param.to_local(), sharded_ref_param.to_local())
        if replicated_param.grad is None:
            cls.assertIsNone(sharded_param.grad)
            continue
        cls.assertIsNotNone(sharded_param.grad)
        sharded_ref_grad = distribute_tensor(replicated_param.grad, mesh, placements)
        cls.assertIsInstance(sharded_param.grad, DTensor)
        assert isinstance(sharded_param.grad, DTensor)  # mypy
        cls.assertEqual(sharded_param.grad.to_local(), sharded_ref_grad.to_local())


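# Illustrative sketch of calling `check_sharded_parity` from a unit test: keep
# a replicated copy of a model, apply `fully_shard` to a deep copy over a 1D
# device mesh, and compare parameters (and gradients, once populated). `cls` is
# the unit test instance; the model dimensions are hypothetical.
def _example_check_sharded_parity(cls, mesh: DeviceMesh) -> None:
    torch.manual_seed(0)
    replicated_model = MLP(dim=8, device=torch.device("cuda"))
    sharded_model = deepcopy(replicated_model)
    fully_shard(sharded_model, mesh=mesh)
    check_sharded_parity(cls, replicated_model, sharded_model)

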
class FSDPTestMultiThread(MultiThreadedTestCase):
    @property
    def world_size(self):
        return torch.cuda.device_count() if torch.cuda.is_available() else 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)

    def perThreadSetUp(self):
        torch._dynamo.reset()

    def perThreadTearDown(self):
        torch._dynamo.reset()


class FSDPTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        # Set TORCH_NCCL_DESYNC_DEBUG=0 to disable the NCCL `workCleanupLoop()`,
        # which can cause unit test flakiness:
        # https://github.com/pytorch/pytorch/issues/90848
        os.environ["TORCH_NCCL_DESYNC_DEBUG"] = "0"
        self._spawn_processes()

    @property
    def world_size(self):
        return min(torch.cuda.device_count(), 8) if torch.cuda.is_available() else 4

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @property
    def init_method(self):
        return f"{FILE_SCHEMA}{self.file_name}"

    def _check_cpu_offload(self, fsdp_model, cpu_offload):
        self.assertEqual(cpu_offload, fsdp_model.cpu_offload)

    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)

    def _check_forward_prefetch(self, fsdp_model, forward_prefetch):
        self.assertEqual(forward_prefetch, fsdp_model.forward_prefetch)

    def run_subtests(self, *args, **kwargs):
        return run_subtests(self, *args, **kwargs)

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        fake_pg = kwargs.get("fake_pg", False)

        print(f"dist init r={self.rank}, world={self.world_size}")

        # Specify the gloo backend to make `init_process_group()` succeed;
        # the actual tests will be skipped if there are not enough GPUs.
        backend = "nccl" if torch.cuda.is_available() else "gloo"

        try:
            if fake_pg:
                store = torch.testing._internal.distributed.fake_pg.FakeStore()
                dist.init_process_group(
                    backend="fake",
                    world_size=self.world_size,
                    rank=rank,
                    store=store,
                )
            else:
                dist.init_process_group(
                    init_method=self.init_method,
                    backend=backend,
                    world_size=int(self.world_size),
                    rank=self.rank,
                )
        except RuntimeError as e:
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        device_ids = None
        if torch.cuda.is_available() and torch.cuda.device_count():
            device_id = self.rank % torch.cuda.device_count()
            torch.cuda.set_device(device_id)
            device_ids = [device_id]

        # Execute barrier prior to running test to ensure that every process
        # has finished initialization and that the following test
        # immediately exiting due to a skip doesn't cause flakiness.
        dist.barrier(device_ids=device_ids)

        torch._dynamo.reset()
        self.run_test(test_name, pipe)
        torch._dynamo.reset()

        dist.barrier(device_ids=device_ids)

        dist.destroy_process_group()

    def _train_for_several_steps(
        self,
        model: nn.Module,
        num_steps: int,
        autocast: bool,
        lr: float = 0.01,
        fsdp_cpu_offload: Optional[CPUOffload] = None,
        save_model: bool = False,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
    ):
        cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params

        model_device = next(model.parameters()).device
        if sharded_grad_scaler_kwargs is None:
            sharded_grad_scaler_kwargs = {}
        sharded_grad_scaler = ShardedGradScaler(
            enabled=enable_sharded_grad_scaler, **sharded_grad_scaler_kwargs
        )
        # Use SGD with momentum instead of Adam, since Adam is scale invariant,
        # which makes it a poor fit for these tests
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.amp.autocast("cuda", enabled=autocast):
                # Inputs are always CUDA regardless of CPU offloading or model.device
                input = model.module.get_input(torch.device("cuda"))
                if use_pure_fp16 or (mixed_precision and not isinstance(model, FSDP)):
                    if isinstance(input, torch.Tensor):
                        input = input.half()
                    else:
                        input = tuple(x.half() for x in input)
                output = model(*input)
                # Post-forward, if CPU offloading, the model params should be on CPU
                if (
                    cpu_offload_params
                    and isinstance(model, FSDP)
                    # If not resharding after forward, the parameters are still
                    # exposed as unsharded views into the GPU flat parameter
                    and model.sharding_strategy
                    not in NO_RESHARD_AFTER_FORWARD_STRATEGIES
                ):
                    for p in model.parameters():
                        # Params should always be on CPU
                        self.assertEqual(p.device, torch.device("cpu"))

                loss = model.module.get_loss(input, output).to(model_device)
            loss = sharded_grad_scaler.scale(loss)

            if not mixed_precision and not use_pure_fp16:
                assert (
                    loss.dtype == torch.float32
                ), "loss data type should be float32, as the original \
                    parameter data type is float32."
            else:
                if use_pure_fp16:
                    self.assertEqual(loss.dtype, torch.float16)
                # FSDP loss is fp16, DDP AMP loss is fp32
                elif isinstance(model, FSDP):
                    assert mixed_precision is not None  # mypy
                    self.assertEqual(loss.dtype, mixed_precision.param_dtype)
                else:
                    self.assertEqual(loss.dtype, torch.float32)
            model.module.run_backward(loss)
            # Post-backward, if CPU offloading, the model params should be on CPU
            if cpu_offload_params and isinstance(model, FSDP):
                for p in model.parameters():
                    # Params should always be on CPU
                    self.assertEqual(p.device, torch.device("cpu"))
            # Unscale the gradients and step
            sharded_grad_scaler.step(optim)
            # Update the scale factor
            sharded_grad_scaler.update()
            # if save_model, simulate save + load.
            if save_model:
                state_dict = {k: v.clone() for k, v in model.state_dict().items()}
                # Zero the params; if saving/loading the state dict did not work
                # properly, this would break the parity test with DDP
                _zero_model(model)
                model.load_state_dict(state_dict)

        if isinstance(model, FSDP):
            model._assert_state(TrainingState.IDLE)
        return loss.detach()  # type: ignore[possibly-undefined]

    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        forward_prefetch: bool = False,
        use_orig_params: bool = False,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        init_kwargs: Optional[Dict[str, Any]] = None,
        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
        """
        assert (
            fsdp_init_mode != FSDPInitMode.NO_FSDP
        ), "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
            sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update(
            {
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "sharding_strategy": sharding_strategy,
                "mixed_precision": mixed_precision,
                "forward_prefetch": forward_prefetch,
                "use_orig_params": use_orig_params,
            }
        )
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Initializing {model_class} raised error {str(e)}") from e
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since we are comparing
            # assuming a data parallel reference and some test models may not
            # do so in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = (
            offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        )
        expects_cpu_device = (
            offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        )
        if expects_cpu_device:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (
            self.assertRaisesRegex(
                RuntimeError,
                "An FSDP-managed module with parameter CPU offloading enabled "
                "has parameters on cuda",
            )
            if expects_device_error
            else nullcontext()
        )
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
                sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        # Do not check dtype since the reference DDP loss may not be the same
        # dtype as the FSDP loss in the case of mixed precision
        torch.testing.assert_close(ref_loss, fsdp_loss, check_dtype=False)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32
        # TODO: Disable checking the parameters for pure FP16 due to floating
        # point inaccuracy. Note that this means that the backward pass is not
        # checked: https://github.com/pytorch/pytorch/issues/90784
        if mixed_precision is None and not use_pure_fp16:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )


def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None):
    def fully_shard_with_compiled_compute(*args, **kwargs):
        torch.distributed._composable.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
        if compile_compute_on_module is None or isinstance(
            args[0], compile_compute_on_module
        ):
            args[0].compile()

    class FullyShardMode(Enum):
        EAGER = auto()
        COMPILED_COMPUTE = auto()

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            original_fully_shard = torch.distributed._composable.fsdp.fully_shard
            for mode in FullyShardMode:
                if mode != FullyShardMode.EAGER and not has_triton():
                    warnings.warn("Inductor on GPU needs Triton and recent GPU arch")
                    continue
                # Barrier to ensure that all threads read the same original values
                original_skip_fsdp_hooks = torch._dynamo.config.skip_fsdp_hooks
                original_compile_threads = torch._inductor.config.compile_threads
                torch.distributed.barrier()

                if mode == FullyShardMode.EAGER:
                    fully_shard_patch = original_fully_shard
                elif mode == FullyShardMode.COMPILED_COMPUTE:
                    torch._dynamo.config.skip_fsdp_hooks = True
                    torch._inductor.config.compile_threads = 1
                    fully_shard_patch = fully_shard_with_compiled_compute  # type: ignore[assignment]
                else:
                    raise NotImplementedError(
                        f"Need to implement FullyShardMode={mode}"
                    )

                # fully_shard is imported as a global
                # through `from ... import fully_shard`
                func.__globals__[original_fully_shard.__name__] = fully_shard_patch
                func(*args, **kwargs)
                # Ensure that other threads finish using the patched func
                # before this thread restores it
                torch.distributed.barrier()
                func.__globals__[original_fully_shard.__name__] = original_fully_shard
                torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
                torch._inductor.config.compile_threads = original_compile_threads

        return wrapper

    return decorator


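# Illustrative sketch of applying `test_compiled_fsdp` to an FSDP test method.
# The class is created inside a factory function so that it is not collected
# as a real test; the class and test names are hypothetical.
def _example_compiled_fsdp_test_class() -> type:
    class _ExampleFullyShardTest(FSDPTest):
        @test_compiled_fsdp(compile_compute_on_module=nn.Linear)
        def test_train_parity(self):
            # The decorated body runs once eagerly and once with compiled
            # compute on `nn.Linear` modules (when Triton is available)
            ...

    return _ExampleFullyShardTest

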
class SkipModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.lin(x)


class NestedLinear(nn.Module):
    def __init__(self, fsdp_wrap):
        super().__init__()
        if fsdp_wrap:
            self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
        else:
            self.nested_linear = nn.Linear(10, 10, bias=False).cuda()

    def forward(self, x):
        return self.nested_linear(x)


class SkipModel(nn.Module):
    def __init__(self, double_nest):
        super().__init__()
        self.linear = nn.Linear(10, 10, bias=False).cuda()
        self.linear_skip = SkipModule().cuda()
        self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))

    def forward(self, x):
        x = self.linear(x)
        x = self.linear_skip(x)
        x = self.nested_linear(x)
        return x