# Owner(s): ["oncall: distributed"]

import contextlib
import itertools
import os
import sys
from functools import partial
from itertools import product
from typing import Any, Dict, List

import torch
import torch.cuda.nccl as nccl
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, size_based_auto_wrap_policy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.swa_utils import AveragedModel
from torch.testing._internal.common_distributed import (
    SaveForwardInputsModel,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    subtest_name,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

skipIfNoTorchVision = skip_but_pass_in_sandcastle_if(
    not HAS_TORCHVISION, "no torchvision"
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

# Various mixed precision configs to test under.
default_mp = MixedPrecision(
    param_dtype=torch.float16,
    buffer_dtype=torch.float16,
    reduce_dtype=torch.float16,
)

# Params and buffers are not cast, comm only happens
# in reduced precision.
mp_only_reduce = MixedPrecision(reduce_dtype=torch.float16)

# Only params and buffers are cast (thus comm should happen in the param_dtype precision)
mp_only_param_and_buf = MixedPrecision(
    param_dtype=torch.float16, buffer_dtype=torch.float16
)

# Nothing is cast (thus param, comm, grad, and buffer should be in the full precision)
mp_no_mixed_precision = MixedPrecision()

nccl_supports_bf16 = dist.is_nccl_available() and nccl.version() >= (2, 10)

mp_configs = [default_mp, mp_only_reduce, mp_only_param_and_buf, mp_no_mixed_precision]
if nccl_supports_bf16:
    mp_diff_buffer_and_reduce = MixedPrecision(
        param_dtype=torch.float16,
        buffer_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    mp_configs.extend([mp_diff_buffer_and_reduce])

# Buffer original dtype, which can differ from model params.
_BUFFER_ORIG_DTYPE = torch.float64

params = "mp_config,cpu_offload,full_precision_param_dtype,enable_sharded_grad_scaler"
cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
full_precision_param_dtype_config = [torch.float32, torch.float64]
enable_sharded_grad_scaler = ["enable_sharded_grad_scaler", None]

configs = list(
    product(
        mp_configs,
        cpu_offload_config,
        full_precision_param_dtype_config,
        enable_sharded_grad_scaler,
    )
)

test_name_mapping = {
    str(CPUOffload(offload_params=True)): "offload_true",
    str(CPUOffload(offload_params=False)): "offload_false",
    str(default_mp): "mp_fp16",
    str(mp_only_reduce): "mp_only_reduce",
    str(mp_only_param_and_buf): "mp_only_param_and_buf",
    str(mp_no_mixed_precision): "mp_no_mp",
    str(torch.float32): "fp32",
    str(torch.float64): "fp64",
    "enable_sharded_grad_scaler": "enable_sharded_grad_scaler",
}

if nccl_supports_bf16:
    test_name_mapping.update(
        {
            str(mp_diff_buffer_and_reduce): "mp_diff_buffer_reduce",
        }
    )

subtest_name = partial(subtest_name, test_name_mapping)

_CURRENT_FULL_PRECISION_PARAM_DTYPE = None

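# Note: FSDP issues its sharded gradient reductions through
# ``dist.reduce_scatter_tensor``, so patching that collective lets the tests
# below observe the dtype that gradient communication actually runs in.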
@contextlib.contextmanager
def patch_reduce_scatter(new_reduce_scatter, full_precision_param_dtype):
    """
    Patches ``dist.reduce_scatter_tensor`` with ``new_reduce_scatter`` and
    restores upon exiting. Used for validation of mixed precision.
    """
    orig_reduce_scatter = dist.reduce_scatter_tensor
    dist.reduce_scatter_tensor = new_reduce_scatter
    global _CURRENT_FULL_PRECISION_PARAM_DTYPE
    _CURRENT_FULL_PRECISION_PARAM_DTYPE = full_precision_param_dtype
    try:
        yield
    finally:
        dist.reduce_scatter_tensor = orig_reduce_scatter
        _CURRENT_FULL_PRECISION_PARAM_DTYPE = None


class LinearMixedPrecision(nn.Module):
    """
    A linear module with extra checks for mixed precision training.
    """

    def __init__(self, param_dtype, buffer_name="buffer", run_checks=True):
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False).to(param_dtype)
        # Use a configurable buffer name to avoid all submodules sharing the
        # same buffer name, which may hide prefixed vs. unprefixed name bugs
        self.buffer_name = buffer_name
        self.register_buffer(buffer_name, torch.randn((1, 2), dtype=_BUFFER_ORIG_DTYPE))
        self._orig_param_type = param_dtype
        self._orig_buffer_dtype = _BUFFER_ORIG_DTYPE
        self.run_checks = run_checks

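    # ``tup`` packs the input together with the test case (``cls``), the root
    # FSDP module, the mixed precision config, and the full precision param
    # dtype so that dtype assertions can run inside the wrapped module's forward.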
    def forward(self, tup):
        inp, cls, fsdp, mp_config, full_precision_param_dtype = tup
        if self.run_checks:
            # Param and input should be the mixed precision type
            expected_param_type = (
                mp_config.param_dtype
                if mp_config.param_dtype is not None
                else self._orig_param_type
            )
            expected_buffer_type = (
                mp_config.buffer_dtype
                if mp_config.buffer_dtype is not None
                else self._orig_buffer_dtype
            )
            cls.assertEqual(inp.dtype, expected_param_type)
            # Buffer should be in specified precision as well.
            cls.assertEqual(getattr(self, self.buffer_name).dtype, expected_buffer_type)

            # In FSDP, self.params should point to the right type.
            num_active_fsdp = 0
            for fsdp_module in FSDP.fsdp_modules(fsdp):
                fsdp_managed_params = fsdp_module.params
                # Single param assumption
                cls.assertEqual(1, len(fsdp_managed_params))
                for param in fsdp_managed_params:
                    # FSDP unit is currently active if it is not using the param
                    # local shard. This supports both FULL_SHARD and SHARD_GRAD_OP
                    # cases. In FULL_SHARD, we have the additional property that
                    # param._full_param_padded has not been freed.
                    param_is_sharded = (
                        fsdp_module.sharding_strategy != ShardingStrategy.NO_SHARD
                        and fsdp_module.world_size > 1
                    )
                    is_fsdp_unit_active = (
                        param_is_sharded
                        and param.data.data_ptr() != param._local_shard.data_ptr()
                    )
                    if is_fsdp_unit_active:
                        num_active_fsdp += 1
                        # This FSDP unit is active, verify param points to mixed
                        cls.assertEqual(param.dtype, expected_param_type)
                        # _unshard should have also freed the fp16 shard.
                        # Shard is never allocated if param_dtype mixed precision is not
                        # enabled.
                        if mp_config.param_dtype is not None:
                            cls.assertEqual(0, param._mp_shard.untyped_storage().size())
                        else:
                            cls.assertFalse(hasattr(param, "_mp_shard"))
                    elif param_is_sharded:
                        # This FSDP unit is not active as full param has been
                        # freed or not yet allocated. Ensure param points to full
                        # precision param.
                        cls.assertEqual(param.dtype, full_precision_param_dtype)
            # We should have gotten at least one active FSDP unit for sharded
            # (world size > 1) cases. For cases where param is not sharded
            # (ie world_size == 1) it is a bit hard to check if FSDP unit is active
            # as we'd always point to the local shard, so we rely on the forward
            # pass self.lin(inp) working well and inp being reduced precision to
            # implicitly validate that the param is indeed in the reduced precision.
            if cls.world_size > 1:
                cls.assertGreater(num_active_fsdp, 0)

        return (self.lin(inp), cls, fsdp, mp_config, full_precision_param_dtype)


class TestFSDPMixedPrecision(FSDPTest):
    @property
    def world_size(self):
        raise ValueError("To be implemented by child classes")

    def _get_simple_nested_model(
        self, param_dtype, run_checks, *fsdp_args, **fsdp_kwargs
    ):
        model = FSDP(
            nn.Sequential(
                FSDP(
                    LinearMixedPrecision(
                        param_dtype, buffer_name="buffer0", run_checks=run_checks
                    ).cuda(),
                    *fsdp_args,
                    **fsdp_kwargs,
                ),
                LinearMixedPrecision(
                    param_dtype, buffer_name="buffer1", run_checks=run_checks
                ).cuda(),
            ),
            *fsdp_args,
            **fsdp_kwargs,
        )
        return model

    def _get_simple_nested_model_composable(
        self, param_dtype, run_checks, *fsdp_args, **fsdp_kwargs
    ):
        model = nn.Sequential(
            LinearMixedPrecision(
                param_dtype, buffer_name="buffer0", run_checks=run_checks
            ).cuda(),
            LinearMixedPrecision(
                param_dtype, buffer_name="buffer1", run_checks=run_checks
            ).cuda(),
        )
        fully_shard(model[0], *fsdp_args, **fsdp_kwargs)
        fully_shard(model, *fsdp_args, **fsdp_kwargs)
        return model

    def _get_simple_model(self, param_dtype, *fsdp_args, **fsdp_kwargs):
        model = FSDP(
            LinearMixedPrecision(param_dtype).cuda(), *fsdp_args, **fsdp_kwargs
        )
        return model

    def _validate_no_mp_shard(self, fsdp_model):
        """
        Validates that there is no mixed precision _mp_shard allocated
        when it is not expected to be.
        """
        fsdp_units = FSDP.fsdp_modules(fsdp_model)
        for fsdp in fsdp_units:
            for param in fsdp.params:
                self.assertFalse(hasattr(param, "_mp_shard"))

    def _validate_mp_shard_freed(self, fsdp_model):
        """
        Ensures that the mixed precision shard is freed for all FSDP units.
301        """
302        fsdp_units = FSDP.fsdp_modules(fsdp_model)
303        for fsdp in fsdp_units:
304            for param in fsdp.params:
305                self.assertEqual(0, param._mp_shard.untyped_storage().size())
306
307    def _reduce_scatter_validate_mp(
308        self, orig_reduce_scatter, mp_config, should_run_low_prec, *args, **kwargs
309    ):
310        """
311        Runs reduce-scatter but verifies mixed precision settings before. This
312        is to test mixed precision is working as expected during backward pass.
313        In particular it ensures that the gradients were cast to the right type
314        and comm. is going to happen in the right type.
315        """
316        tensors = []
317        for x in args:
318            if isinstance(x, torch.Tensor):
319                tensors.append(x)
320        for x in kwargs.values():
321            if isinstance(x, torch.Tensor):
322                tensors.append(x)
323
324        # reduce_dtype has higher priority than param_dtype, because mixed_precision
325        # supports overriding param_dtype with reduce_dtype to control the
326        # reduction precision. In the case where reduce_dtype == param_dtype
327        # this tests that gradients are in the expected precision as well.
328        # If reduce_dtype is not specified (is None) we comm. in the param_dtype
329        # if that is specified, otherwise full precision dtype.
330        if should_run_low_prec:
331            expected_dtype = (
332                mp_config.reduce_dtype
333                if mp_config.reduce_dtype is not None
334                else (
335                    mp_config.param_dtype
336                    if mp_config.param_dtype is not None
337                    else _CURRENT_FULL_PRECISION_PARAM_DTYPE
338                )
339            )
340        else:
341            expected_dtype = _CURRENT_FULL_PRECISION_PARAM_DTYPE
342
343        for t in tensors:
344            self.assertEqual(
345                expected_dtype,
346                t.dtype,
347                f"Expected to reduce in {expected_dtype} but got tensors in {t.dtype}",
348            )
349
350        return orig_reduce_scatter(*args, **kwargs)
351
352    def _test_grads_reduced_precision(
353        self, offload_params: bool, use_orig_params: bool
354    ):
355        class MyModel(nn.Module):
356            def __init__(self) -> None:
357                super().__init__()
358                self.lin1 = nn.Linear(10, 10)
359                self.lin2 = nn.Linear(10, 10)
360
361            def forward(self, x):
362                return self.lin2(self.lin1(x))
363
364        m = MyModel().cuda()
365        mp = MixedPrecision(
366            param_dtype=torch.float16,
367            reduce_dtype=torch.float16,
368            buffer_dtype=torch.float16,
369            keep_low_precision_grads=True,
370        )
371        fsdp_kwargs = {
372            "mixed_precision": mp,
373            "cpu_offload": CPUOffload(offload_params=offload_params),
374            "use_orig_params": use_orig_params,
375        }
376        m.lin1 = FSDP(m.lin1, **fsdp_kwargs)
377        m = FSDP(m, **fsdp_kwargs)
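        # With ``keep_low_precision_grads=True``, FSDP is expected to leave the
        # gradients in the reduced dtype after backward instead of upcasting
        # them back to full precision, which the loop below asserts.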
        for _ in range(6):
            inp = torch.ones(1, 10)
            m(inp).sum().backward()
            for param in m.parameters():
                if param.grad is not None:
                    self.assertEqual(torch.float16, param.grad.dtype)

        dist.barrier()

    def _run_test_mixed_precision_e2e(
        self,
        mp_config,
        cpu_offload,
        backward_prefetch,
        forward_prefetch,
        full_precision_param_dtype,
        sharding_strategy,
        enable_sharded_grad_scaler,
    ):
        torch.cuda.set_device(self.rank)
        fsdp_models = [
            self._get_simple_model(
                param_dtype=full_precision_param_dtype,
                sharding_strategy=sharding_strategy,
                cpu_offload=cpu_offload,
                mixed_precision=mp_config,
                backward_prefetch=backward_prefetch,
                forward_prefetch=forward_prefetch,
            ),
            self._get_simple_nested_model(
                param_dtype=full_precision_param_dtype,
                run_checks=True,
                sharding_strategy=sharding_strategy,
                cpu_offload=cpu_offload,
                mixed_precision=mp_config,
                backward_prefetch=backward_prefetch,
                forward_prefetch=forward_prefetch,
            ),
        ]
        for model in fsdp_models:
            if not cpu_offload.offload_params:
                model.cuda()

            # Patch reduce_scatter to add validation for mixed precision types.
            orig_reduce_scatter = dist.reduce_scatter_tensor
            test_reduce_scatter = partial(
                self._reduce_scatter_validate_mp,
                orig_reduce_scatter,
                mp_config,
                True,
            )
            with patch_reduce_scatter(test_reduce_scatter, full_precision_param_dtype):
                scaler = ShardedGradScaler(enabled=enable_sharded_grad_scaler)
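                # ShardedGradScaler is the FSDP-aware counterpart of
                # ``torch.cuda.amp.GradScaler``; when constructed with
                # ``enabled=False``, its scale/step/update calls are
                # pass-throughs, so one loop covers both scaler configurations.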
                optim = torch.optim.Adam(model.parameters())

                for _ in range(3):
                    inp = torch.randn(
                        3, 10, device="cuda", dtype=full_precision_param_dtype
                    )
                    # The forward pass of LinearMixedPrecision checks the
                    # casting of inputs, params, and buffers.
                    act, *_ = model(
                        (inp, self, model, mp_config, full_precision_param_dtype)
                    )
                    # Buffers should be cast.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                    # p._mp_shard should be freed.
                    if mp_config.param_dtype is not None:
                        self._validate_mp_shard_freed(model)
                    else:
                        # We never should have allocated an _mp_shard.
                        self._validate_no_mp_shard(model)

                    loss = act.sum()
                    loss = scaler.scale(loss)
                    if mp_config.param_dtype is not None:
                        self.assertEqual(loss.dtype, mp_config.param_dtype)
                    else:
                        self.assertEqual(loss.dtype, full_precision_param_dtype)
                    # Will run patched reduce scatter that validates mixed_precision
                    # types in backward.
                    loss.backward()
                    # Buffers stay cast even after backward.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)
                    # p._mp_shard should be freed.
                    if mp_config.param_dtype is not None:
                        self._validate_mp_shard_freed(model)
                    else:
                        self._validate_no_mp_shard(model)

                    # Ensure params and grads are in full precision,
                    # as after fwd/backward we maintain full precision shards.
                    for param in model.parameters():
                        self.assertEqual(param.dtype, full_precision_param_dtype)
                        if param.grad is not None:
                            self.assertEqual(
                                param.grad.dtype, full_precision_param_dtype
                            )

                    # Unscale the gradients and step
                    scaler.step(optim)
                    # Update the scale factor
                    scaler.update()

                    # Summon full params should be in full precision
                    with model.summon_full_params(model):
                        # It is not expected for summon_full_params to allocate
                        # a mixed precision shard.
                        if mp_config.param_dtype is not None:
                            self._validate_mp_shard_freed(model)
                        else:
                            self._validate_no_mp_shard(model)
                        params = list(model.parameters())
                        for p in params:
                            self.assertEqual(p.dtype, full_precision_param_dtype)

                        # Note that buffers are cast only once and only restored
                        # to the original buffer dtype in state_dict, so
                        # summon_full_params is not expected to restore buffer
                        # types to their original.
                        named_buffers = dict(model.named_buffers())
                        for v in named_buffers.values():
                            if mp_config.buffer_dtype is not None:
                                self.assertEqual(v.dtype, mp_config.buffer_dtype)
                            else:
                                self.assertEqual(v.dtype, _BUFFER_ORIG_DTYPE)

                    # state_dict should be in full precision
                    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
                    for name, tensor in state_dict.items():
                        # Parameters and buffers are checkpointed in their
                        # original dtypes, which may be different.
                        if name in named_buffers.keys():
                            self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE)
                        else:
                            self.assertEqual(
                                tensor.dtype,
                                full_precision_param_dtype,
                                f"{name}: {tensor.dtype} vs {full_precision_param_dtype}",
                            )

                    # After state_dict, buffer's dtype should have been restored
                    # to the mixed precision one.
                    for buf in model.buffers():
                        if mp_config.buffer_dtype is not None:
                            self.assertEqual(buf.dtype, mp_config.buffer_dtype)
                        else:
                            self.assertEqual(buf.dtype, _BUFFER_ORIG_DTYPE)


class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision):
    @property
    def world_size(self):
        return 2

    def _get_subtest_config(self) -> Dict[str, List[Any]]:
        """Returns a subtest configuration that subtests prefetching settings
        together."""
        return {
            "forward_prefetch": [False, True],
            "backward_prefetch": [
                None,
                BackwardPrefetch.BACKWARD_PRE,
                BackwardPrefetch.BACKWARD_POST,
            ],
        }

    @skip_if_lt_x_gpu(2)
    def test_mixed_precision_no_reshard_after_forward(self):
        # Note that we don't exercise all possible different configs so as to
        # not increase test TTS too much.
        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
        self._run_test_mixed_precision_e2e(
            mp_config=mp,
            cpu_offload=CPUOffload(offload_params=True),
            backward_prefetch=None,
            forward_prefetch=False,
            full_precision_param_dtype=torch.float64,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            enable_sharded_grad_scaler=False,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize(params, configs, subtest_name)
    def test_mixed_precision_e2e_full_shard(
        self,
        mp_config,
        cpu_offload,
        full_precision_param_dtype,
        enable_sharded_grad_scaler,
    ):
        self.run_subtests(
            self._get_subtest_config(),
            self._run_test_mixed_precision_e2e,
            mp_config=mp_config,
            cpu_offload=cpu_offload,
            full_precision_param_dtype=full_precision_param_dtype,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
        )

    def _test_mixed_precision_embedding_table(self, mp_config):
        # Basic test to ensure int inputs are not cast, which would break
        # modules such as embedding tables.
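        # (FSDP is expected to cast only floating point forward inputs, so
        # integer token indices should reach the embedding lookup unchanged.)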
        param_dtype = mp_config.param_dtype or torch.float32
        orig_reduce_scatter = dist.reduce_scatter_tensor
        test_reduce_scatter = partial(
            self._reduce_scatter_validate_mp,
            orig_reduce_scatter,
            mp_config,
            True,
        )
        with patch_reduce_scatter(test_reduce_scatter, param_dtype):
            # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
            # entire `TransformerWithSharedParams` with a single top-level FSDP
            model = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                {"mixed_precision": mp_config},
            )
            fsdp_model = FSDP(model, mixed_precision=mp_config)
            optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
            for _ in range(6):
                inp = fsdp_model.module.get_input(torch.device("cuda"))
                # This would fail if we cast integer module inputs such as for
                # embedding tables.
                output = fsdp_model(*inp)
                loss = fsdp_model.module.get_loss(inp, output).cuda()
                self.assertEqual(loss.dtype, param_dtype)
                fsdp_model.module.run_backward(loss)
                optim.step()

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_reduce(self):
        self._test_mixed_precision_embedding_table(
            mp_config=MixedPrecision(reduce_dtype=torch.float16)
        )

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_only_params_and_bufs(self):
        self._test_mixed_precision_embedding_table(
            mp_config=MixedPrecision(
                param_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
        )

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_default(self):
        default_mp_config = MixedPrecision(
            param_dtype=torch.float16,
            buffer_dtype=torch.float16,
            reduce_dtype=torch.float16,
        )
        self._test_mixed_precision_embedding_table(mp_config=default_mp_config)

    @skip_if_lt_x_gpu(2)
    def test_mp_embedding_params_and_reduce_diff(self):
        params_and_reduce_different = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float16,
        )
        self._test_mixed_precision_embedding_table(
            mp_config=params_and_reduce_different
        )

    @skip_if_lt_x_gpu(2)
    @skipIfNoTorchVision
    def test_mixed_precision_resnet(self):
        """
        End to end test to ensure mixed precision + auto_wrap works
        for ResNet model.
        """
        resnet_model = torchvision.models.resnet50().cuda()
        resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
            resnet_model, process_group=dist.distributed_c10d._get_default_group()
        )
        n_bn = sum(
            1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules()
        )
        inp = torch.ones(1, 3, 1000, 1000, device="cuda")
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        fsdp = FSDP(
            resnet_model,
            auto_wrap_policy=size_based_auto_wrap_policy,
            mixed_precision=mp_config,
        )
        # Batchnorm units should be wrapped individually. Validate this by
        # ensuring that the number of FSDP units wrapping a batchnorm module
        # equals the number of batchnorm units in the original ResNet model.
        fsdp_bn = 0
        for module in fsdp.fsdp_modules(fsdp):
            wrapped_module = module.module
            if isinstance(wrapped_module, _BatchNorm):
                fsdp_bn += 1

        self.assertEqual(fsdp_bn, n_bn)
        # Would throw type mismatch issue without mixed precision autowrapping.
        loss = fsdp(inp).sum()
        loss.backward()

    @skip_if_lt_x_gpu(2)
    def test_grads_reduced_precision(self):
        self.run_subtests(
            {
                "offload_params": [False, True],
                "use_orig_params": [False, True],
            },
            self._test_grads_reduced_precision,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("convert_sync_bn", [True, False])
    def test_mp_batchnorm(self, convert_sync_bn):
        class BatchNormNet(nn.Module):
            def __init__(self, affine=True):
                super().__init__()
                self.fc1 = nn.Linear(2, 40, bias=False)
                self.bn = nn.BatchNorm1d(4, affine=affine)
                self.fc2 = nn.Linear(40, 4, bias=False)
                self.ln = nn.LayerNorm(4)
                self.fc3 = nn.Linear(4, 4, bias=False)

            def forward(self, x):
                x = torch.reshape(self.fc1(x), (-1, 4, 10))
                x = self.bn(x)
                x = torch.reshape(x, (-1, 40))
                x = self.fc2(x)
                x = self.ln(x)
                x = self.fc3(x)
                return F.softmax(x, dim=1)

        def never_wrap_policy(*args, **kwargs):
            return False

        net = BatchNormNet().cuda()
        if convert_sync_bn:
            net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
        # FSDP detects that mixed precision + batchnorm will cause issues
        # and thus wraps each batchnorm in a distinct FSDP unit that does not
        # use mixed precision.
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
            _module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
        )
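        # ``_module_classes_to_ignore`` names module classes that FSDP keeps
        # out of mixed precision by wrapping them in their own FSDP units with
        # mixed precision disabled, which the assertions below verify.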
        with self.assertWarnsRegex(
            expected_warning=UserWarning,
            expected_regex="These modules will be wrapped as separate FSDP",
        ):
            model = FSDP(
                net,
                mixed_precision=mp_config,
                auto_wrap_policy=never_wrap_policy,
            )

        no_mp = MixedPrecision()
        for mod in [model.ln, model.bn]:
            self.assertTrue(isinstance(mod, FSDP))
            self.assertEqual(no_mp, mod.mixed_precision)
        # policy should not have wrapped any other submodules
        for mod in [model.fc1, model.fc2, model.fc3]:
            self.assertFalse(isinstance(mod, FSDP))

        # Overall mixed precision is still enabled
        self.assertEqual(mp_config, model.mixed_precision)

        inp = torch.randn((1, 2), device="cuda")
        # Without FSDP BN mixed precision fix, this would result in
        # RuntimeError: Expected counts to have type Half but got Float
        # for syncBN
        model(inp).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_eval_root_cast_inputs(self):
        """
        In a case where root module does not manage FSDP parameters,
        ensure that we don't cast forward inputs which could potentially
        cause a dtype mismatch. Check that FSDP_USE_FULL_PREC_IN_EVAL controls
        this.
        """

        low_prec_dtype = torch.float16

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(5, 5)

            def forward(self, x, expect_use_full_prec_in_eval):
                if expect_use_full_prec_in_eval:
                    assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}"
                else:
                    assert (
                        x.dtype == low_prec_dtype
                    ), f"Expected {low_prec_dtype}, got {x.dtype}"
                return self.a(x)

        mp_config = MixedPrecision(
            param_dtype=low_prec_dtype,
            reduce_dtype=low_prec_dtype,
            buffer_dtype=low_prec_dtype,
        )

        for use_full_prec_in_eval in [True, False]:
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            m = MyModel().cuda()
            m.a = FSDP(m.a, mixed_precision=mp_config)
            model = FSDP(m, mixed_precision=mp_config)
            model.eval()
            inp = torch.randn(5, 5)
            model(inp, use_full_prec_in_eval).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_full_precision_in_eval(self):
        """
        Tests that eval runs in full precision if FSDP_USE_FULL_PREC_IN_EVAL is set.
        """
        for (
            use_composable,
            cast_forward_inputs,
            use_full_prec_in_eval,
        ) in itertools.product([True, False], [True, False], [True, False]):
            mp_config = MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
                cast_forward_inputs=cast_forward_inputs,
            )
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            model = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.NO_FSDP if use_composable else FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                {"mixed_precision": mp_config},
            )
            if use_composable:
                auto_wrap_policy = ModuleWrapPolicy(
                    {
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    }
                )
                fully_shard(model, policy=auto_wrap_policy, mixed_precision=mp_config)
            module_accessor = model if use_composable else model.module
            inp = module_accessor.get_input(torch.device("cuda"))
            output = model(*inp)
            loss = module_accessor.get_loss(inp, output).cuda()
            # Loss should be in fp16
            self.assertEqual(torch.float16, loss.dtype)
            module_accessor.run_backward(loss)
            # Grads should be in fp32 as we upcast them
            for p in model.parameters():
                if p.grad is not None:
                    self.assertEqual(torch.float32, p.grad.dtype)

            # Now in eval mode, loss should be fp32 if use_full_prec_in_eval is set.
            model.eval()
            inp = module_accessor.get_input(torch.device("cuda"))
            output = model(*inp)
            loss = module_accessor.get_loss(inp, output).cuda()
            expected_dtype = torch.float32 if use_full_prec_in_eval else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)

    @skip_if_lt_x_gpu(2)
    def test_full_precision_in_eval_buffers(self):
        """
        Tests that when model.eval() and FSDP_USE_FULL_PREC_IN_EVAL is set,
        buffers are in the full precision.
        """
        for (
            use_composable,
            cast_forward_inputs,
            use_full_prec_in_eval,
        ) in itertools.product([True, False], [True, False], [True, False]):
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            mp_config = MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
                cast_forward_inputs=cast_forward_inputs,
            )
            model_getter = (
                self._get_simple_nested_model_composable
                if use_composable
                else self._get_simple_nested_model
            )
            fsdp_model = model_getter(
                param_dtype=torch.float32,
                run_checks=False,
                mixed_precision=mp_config,
            )

            inp = torch.randn(3, 10, device="cuda")
            fsdp_model((inp, self, fsdp_model, mp_config, torch.float32))
            for buf in fsdp_model.buffers():
                self.assertEqual(torch.float16, buf.dtype)

            # model.eval() + forward pass should restore the buffers to full
            # precision again when FSDP_USE_FULL_PREC_IN_EVAL is set
            # Add pre-forward hooks
            def verify_eval_buffer_dtype(module, input):
                expected_dtype = (
                    _BUFFER_ORIG_DTYPE if use_full_prec_in_eval else torch.float16
                )
                for buf in module.buffers():
                    self.assertEqual(expected_dtype, buf.dtype)

            def _get_underlying_module(m):
                return m.module if isinstance(m, FSDP) else m

            hook_handles = []
            hook_handles.append(
                _get_underlying_module(fsdp_model[0]).register_forward_pre_hook(
                    verify_eval_buffer_dtype
                )
            )
            hook_handles.append(
                _get_underlying_module(fsdp_model[1]).register_forward_pre_hook(
                    verify_eval_buffer_dtype
                )
            )

            fsdp_model.eval()
            fsdp_model((inp, self, fsdp_model, mp_config, torch.float32))
            for hook_handle in hook_handles:
                hook_handle.remove()

            expected_dtype = (
                _BUFFER_ORIG_DTYPE if use_full_prec_in_eval else torch.float16
            )
            for buf in fsdp_model.buffers():
                self.assertEqual(expected_dtype, buf.dtype)

            # model.train() + forward again should make buffers in fp16
            fsdp_model.train()
            fsdp_model((inp, self, fsdp_model, mp_config, torch.float32))
            for buf in fsdp_model.buffers():
                self.assertEqual(torch.float16, buf.dtype)

    @skip_if_lt_x_gpu(2)
    def test_full_precision_in_eval_comm(self):
        for (
            use_composable,
            cast_forward_inputs,
            use_full_prec_in_eval,
        ) in itertools.product([True, False], [True, False], [True, False]):
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = (
                "1" if use_full_prec_in_eval else "0"
            )
            mp_config = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float32,
                cast_forward_inputs=cast_forward_inputs,
                # Cast the reduction for batchnorm as well, just in this test,
                # to make validation easier.
                _module_classes_to_ignore=[],
            )
            model = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.NO_FSDP if use_composable else FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                {"mixed_precision": mp_config},
            )
            if use_composable:
                auto_wrap_policy = ModuleWrapPolicy(
                    {
                        TransformerEncoderLayer,
                        TransformerDecoderLayer,
                    }
                )
                fully_shard(model, policy=auto_wrap_policy, mixed_precision=mp_config)
            model_accessor = model if use_composable else model.module
            # Patch reduce_scatter to add validation for mixed precision types.
            orig_reduce_scatter = dist.reduce_scatter_tensor
            test_reduce_scatter = partial(
                self._reduce_scatter_validate_mp,
                orig_reduce_scatter,
                mp_config,
                not use_full_prec_in_eval,
            )
            model.eval()
            with patch_reduce_scatter(test_reduce_scatter, torch.float32):
                inp = model_accessor.get_input(torch.device("cuda"))
                output = model(*inp)
                loss = model_accessor.get_loss(inp, output).cuda()
                model_accessor.run_backward(loss)

    @skip_if_lt_x_gpu(2)
    def test_input_grads_with_param_mixed_precision(self):
        """
        Tests that input tensors that require gradients do get their gradients
        even after being cast to a low precision (when parameter mixed
        precision is enabled).
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "use_orig_params": [False, True],
            },
            self._test_input_grads_with_param_mixed_precision,
        )

    def _test_input_grads_with_param_mixed_precision(
        self,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        model = nn.Linear(1024, 1024, bias=False)
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        )
        fsdp_model = FSDP(
            model,
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
            device_id=torch.cuda.current_device(),
            use_orig_params=use_orig_params,
        )
        # Use an input with dtype not equal to the mixed precision
        # `param_dtype` so that it gets cast
        x_float = torch.randn(
            (32, 1024),
            device="cuda",
            dtype=torch.float32,
            requires_grad=True,
        )
        fsdp_model(x_float).sum().backward()
        self.assertTrue(x_float.grad is not None)
        # Check that `x_float` preserves its dtype, meaning that the gradient
        # propagated via `ToCopyBackward0`
        self.assertEqual(x_float.grad.dtype, torch.float32)


class TestFSDPMixedPrecisionUnsharded(TestFSDPMixedPrecision):
    """
    Smaller test suite for the unsharded param (i.e. world_size == 1) case.
1042    """
1043
1044    @property
1045    def world_size(self):
1046        return 1
1047
1048    @skip_if_lt_x_gpu(1)
1049    def test_grads_reduced_precision(self):
1050        self.run_subtests(
1051            {"offload_params": [False, True], "use_orig_params": [False, True]},
1052            self._test_grads_reduced_precision,
1053        )
1054
1055    @skip_if_lt_x_gpu(1)
1056    def test_mixed_precision_no_reshard_after_forward(self):
1057        # Note that we don't exercise all possible different configs so as to
1058        # not increase test TTS too much.
1059        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
1060        self._run_test_mixed_precision_e2e(
1061            mp_config=mp,
1062            cpu_offload=CPUOffload(offload_params=True),
1063            backward_prefetch=None,
1064            forward_prefetch=False,
1065            full_precision_param_dtype=torch.float64,
1066            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
1067            enable_sharded_grad_scaler=False,
1068        )
1069
1070    @skip_if_lt_x_gpu(1)
1071    def test_mixed_precision_e2e_full_shard(self):
1072        mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
1073        self._run_test_mixed_precision_e2e(
1074            mp_config=mp,
1075            cpu_offload=CPUOffload(offload_params=True),
1076            backward_prefetch=None,
1077            forward_prefetch=False,
1078            full_precision_param_dtype=torch.float64,
1079            sharding_strategy=ShardingStrategy.FULL_SHARD,
1080            enable_sharded_grad_scaler=False,
1081        )
1082
1083
1084instantiate_parametrized_tests(TestFSDPMixedPrecisionSharded)
1085
1086
1087class IgnoredModule(nn.Module):
1088    def __init__(self) -> None:
1089        super().__init__()
1090        self.l = nn.Linear(100, 100)
1091
1092    def forward(self, x):
1093        return self.l(x)
1094
1095
1096class ModelWithIgnoredModule(nn.Module):
1097    def __init__(self) -> None:
1098        super().__init__()
1099        self.l1 = nn.Linear(100, 100)
1100        self.ignored = IgnoredModule()
1101        self.l2 = nn.Linear(100, 100)
1102
1103    def forward(self, x):
1104        return self.l2(self.ignored(self.l1(x)))
1105
1106
1107class TestFSDPMixedPrecisionIgnoredModules(FSDPTest):
1108    @property
1109    def world_size(self):
1110        return 1
1111
1112    @skip_if_lt_x_gpu(1)
1113    def test_mixed_precision_with_ignored_module(self):
1114        model = ModelWithIgnoredModule().cuda()
1115        float16 = MixedPrecision(param_dtype=torch.float16)
1116        model = FSDP(
1117            model,
1118            ignored_modules=[model.ignored],
1119            mixed_precision=float16,
1120        )
1121
1122        x = torch.ones(2, 100, device=torch.cuda.current_device())
1123
1124        with self.assertRaisesRegex(RuntimeError, "must have the same dtype"):
1125            model(x).sum().backward()
1126
1127
1128class TestFSDPDifferentSubmodulePrecision(FSDPTest):
1129    @property
1130    def world_size(self):
1131        return 2
1132
1133    @skip_if_lt_x_gpu(2)
1134    def test_float16_on_one_submodule(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs,
            cast_forward_inputs=False,
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        # float16 on one submodule and float32 on everything else
        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model)

        fsdp(x).sum().backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_float16_on_one_submodule_skip_inputs(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=True
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        # float16 on one submodule and float32 on everything else
        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model)

        fsdp(x).sum().backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[c2].dtype, torch.float32)

    @skip_if_lt_x_gpu(2)
    def test_float16_on_one_submodule_skip_inputs_error(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=False)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        # float16 on one submodule and float32 on everything else
        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model)

        with self.assertRaisesRegex(
            RuntimeError, "mat1 and mat2 must have the same dtype"
        ):
            fsdp(x).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_submodules_with_different_precisions_error(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        x = torch.zeros(2, 100, device="cuda")

        # For submodules with different precisions, the current design does not
        # support the case where the root FSDP instance wraps a submodule that
        # is not the first one executed: that submodule's inputs (i.e. the
        # previous submodule's outputs) have no way to be cast; only the root
        # module's inputs are cast upfront before entering the root module's
        # forward.
        model.c1 = FSDP(model.c1, mixed_precision=float16)
        fsdp = FSDP(model, mixed_precision=float32)
        with self.assertRaisesRegex(
            RuntimeError, "mat1 and mat2 must have the same dtype"
        ):
            fsdp(x).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_submodules_with_different_precisions(self):
        forward_inputs: Dict[nn.Module, torch.Tensor] = {}
        float16 = MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)
        float32 = MixedPrecision(param_dtype=torch.float32, cast_forward_inputs=True)

        model = SaveForwardInputsModel(
            forward_inputs=forward_inputs, cast_forward_inputs=False
        ).cuda()
        c1, c2 = model.c1, model.c2
        x = torch.zeros(2, 100, device="cuda")

        model.c2 = FSDP(model.c2, mixed_precision=float16)
        fsdp = FSDP(model, mixed_precision=float32)

        fsdp(x).sum().backward()

        self.assertEqual(forward_inputs[model].dtype, torch.float32)
        self.assertEqual(forward_inputs[c1].dtype, torch.float32)
        self.assertEqual(forward_inputs[c2].dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    def test_submodules_with_external_inputs(self):
        class ToyModule(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l = nn.Linear(100, 100)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["l2_input_x"] = x
                self.forward_inputs["l2_input_y"] = y
                return self.l(x)

        class ToyModel(nn.Module):
            def __init__(self, forward_inputs: Dict[str, torch.Tensor]) -> None:
                super().__init__()
                self.l1 = nn.Linear(100, 100)
                self.l2 = ToyModule(forward_inputs)
                self.forward_inputs = forward_inputs

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.forward_inputs["model_input_x"] = x
                y = torch.ones(2, 100, device="cuda", dtype=torch.float32)
                return self.l2(self.l1(x), y)

        forward_inputs: Dict[str, torch.Tensor] = {}

        float16 = MixedPrecision(param_dtype=torch.float16)
        model = ToyModel(forward_inputs).cuda()
        x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
        model.l2 = FSDP(model.l2, mixed_precision=float16)
        fsdp = FSDP(model, mixed_precision=float16)

        fsdp(x).sum().backward()

        # Inputs are cast by the root module by default; inputs of submodules
        # are not explicitly cast, so the external input ``y`` of module
        # ``self.l2`` is not cast.
        self.assertEqual(forward_inputs["model_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_x"].dtype, torch.float16)
        self.assertEqual(forward_inputs["l2_input_y"].dtype, torch.float32)


class TestFSDPTrainEval(FSDPTest):
    @property
    def world_size(self):
        return 2

    @skip_if_lt_x_gpu(2)
    def test_train_ema_eval_flow(self):
        """
        Tests a train -> EMA update -> eval flow with mixed precision enabled.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    # We mainly want to test `SHARD_GRAD_OP` since it surfaced
                    # the original bug of not using the right EMA parameters
                    # for eval, but we also test the others for completeness
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_train_ema_eval_flow,
        )

    def _test_train_ema_eval_flow(self, sharding_strategy: ShardingStrategy):
        class TransformerWithEMA(nn.Module):
            def __init__(self, device: torch.device):
                super().__init__()
                self.module = nn.Transformer(device=device)
                self.ema_module = AveragedModel(
                    nn.Transformer(device=device),
                    multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(),
                    use_buffers=True,
                )

            def forward(self, *args, **kwargs):
                # Use main copy for training and EMA copy for eval
                if self.training:
                    return self.module(*args, **kwargs)
                return self.ema_module(*args, **kwargs)

        device = torch.device("cuda")
        model = TransformerWithEMA(device=device)
        policy = ModuleWrapPolicy(
            {nn.Transformer, nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
        )
        mixed_precision = MixedPrecision(param_dtype=torch.float16)
        fsdp_model = FSDP(
            model,
            auto_wrap_policy=policy,
            mixed_precision=mixed_precision,
            sharding_strategy=sharding_strategy,
        )
        optim = torch.optim.Adam(fsdp_model.module.parameters(), lr=1e-2)
        if self.rank == 0:
            print(fsdp_model)
        torch.manual_seed(1 + self.rank)
        eval_src = torch.randn((8, 1, 512), device=device)
        eval_tgt = torch.randn((16, 1, 512), device=device)
        eval_out_sums: List[torch.Tensor] = []
        # An iteration consists of training forward/backward/optimizer,
        # updating the EMA copy with the main copy, and eval forward
        for _ in range(3):
            fsdp_model.train()
            train_src = torch.randn((8, 4, 512), device=device)
            train_tgt = torch.randn((16, 4, 512), device=device)
            train_out = fsdp_model(train_src, train_tgt)
            train_out.sum().backward()
            optim.step()
            optim.zero_grad()
            with FSDP.summon_full_params(fsdp_model):
                fsdp_model.ema_module.update_parameters(fsdp_model.module)
            fsdp_model.eval()
            with torch.no_grad():
                eval_out = fsdp_model(eval_src, eval_tgt)
            eval_out_sums.append(eval_out.sum())
        # Check that the eval outputs differ from iteration to iteration as a
        # proxy for eval using the correct EMA parameters
        for i in range(len(eval_out_sums) - 1):
            self.assertNotEqual(eval_out_sums[i], eval_out_sums[i + 1])
        self.assertNotEqual(eval_out_sums[0], eval_out_sums[-1])


if __name__ == "__main__":
    run_tests()