# Owner(s): ["oncall: distributed"]

import functools
import os
import sys
import warnings
from collections import namedtuple
from contextlib import nullcontext
from copy import deepcopy
from itertools import chain
from typing import Any, Tuple

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp import (
    CPUOffload,
    FlatParameter,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp._flat_param import _FSDP_USE_UNSAFE_SETATTR
from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES
from torch.distributed.fsdp.wrap import (
    always_wrap_policy,
    ModuleWrapPolicy,
    transformer_auto_wrap_policy,
)
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class MyModel(nn.Module):
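    """A small two-linear-layer model shared by several tests in this file."""
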
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(2, 2)
        self.b = nn.Linear(2, 2)

    def forward(self, x, y):
        return self.b(self.a(x + y))


class TestFSDPMiscMultiProcess(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_index", [True, False])
    def test_fsdp_device_id(self, use_index):
        """
        Tests the FSDP ``device_id`` argument:
          - Wrapping a CPU module should move the module to the GPU matching
            ``device_id``
          - Wrapping a GPU module already on the GPU matching ``device_id``
            should not raise an error
          - Wrapping a GPU module already on GPU and passing a GPU device
            without specifying a device ID (i.e. ``torch.device("cuda")``) warns
        """
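        # Minimal sketch of the first behavior (assuming a CUDA device is
        # available): wrapping a CPU module with an explicit `device_id` moves
        # its parameters onto that device, e.g.
        #   fsdp = FSDP(nn.Linear(2, 2), device_id=torch.cuda.current_device())
        #   assert next(fsdp.parameters()).device.type == "cuda"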
        dev_id = (
            torch.cuda.current_device()
            if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(module, device_id):
            """Checks that the ``FlatParameter``s in ``module`` have device
            matching ``device_id``."""
            devices = {
                p.device for p in module.parameters() if isinstance(p, FlatParameter)
            }
            assert len(devices) > 0
            self.assertEqual(1, len(devices))
            found_device = devices.pop()
            if use_index and not isinstance(device_id, torch.device):
                device = torch.device("cuda", device_id)
            else:
                device = device_id
            self.assertEqual(found_device, device)

        # Check that FSDP parameters are moved to `device_id` for a CPU module
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that specifying `device_id` for a GPU module already on that
        # device does not raise an error
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that passing in `torch.device("cuda")` for a GPU module warns
        regex = "does not have an explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs={"device_id": torch.device("cuda")},
            )
        _check_device_matches(
            nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_zero2_eval_with_prefetch(self):
        # Tests FSDP with SHARD_GRAD_OP and forward_prefetch against DDP in
        # both training and eval

        class Mnist(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout(0.25)
                self.dropout2 = nn.Dropout(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)
                self.ln = nn.LayerNorm(9216)

            def forward(self, x, y):
                x = self.conv1(x)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.ln(x)
                x = self.fc1(x)
                x = torch.nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = torch.nn.functional.log_softmax(x, dim=1)
                loss = torch.nn.functional.cross_entropy(output, y)
                return loss

        model = Mnist().cuda()
        model1 = Mnist().cuda()
        model1.load_state_dict(model.state_dict())
        fsdp_model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            forward_prefetch=True,
            use_orig_params=True,
            auto_wrap_policy=ModuleWrapPolicy([nn.Linear, nn.Conv2d]),
        )
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            model1,
        )

        fsdp_opt = torch.optim.SGD(fsdp_model.parameters(), lr=1e-4)
        ddp_opt = torch.optim.SGD(ddp_model.parameters(), lr=1e-4)

        seed = self.rank + 20231010
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        losses = []
        grads = []
        for i in range(5):
            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
            for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                seed = self.rank + i
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                loss = model(x, y).sum()
                losses.append(loss)
                loss.backward()
                opt.step()
                grads.append(x.grad)
                opt.zero_grad()
            assert torch.allclose(losses[0], losses[1])
            assert torch.allclose(grads[0], grads[1])
            losses.clear()
            grads.clear()

        with torch.no_grad():
            fsdp_model.eval()
            ddp_model.eval()
            for _ in range(5):
                x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
                y = torch.randint(low=0, high=9, size=(8,), device="cuda")
                fsdp_loss = fsdp_model(x, y)
                ddp_loss = ddp_model(x, y)
                assert torch.allclose(fsdp_loss, ddp_loss)

        fsdp_model.train()
        ddp_model.train()
        for i in range(5):
            x = torch.randn(8, 1, 28, 28, device="cuda").requires_grad_()
            y = torch.randint(low=0, high=9, size=(8,), device="cuda")
            for model, opt in ((fsdp_model, fsdp_opt), (ddp_model, ddp_opt)):
                seed = self.rank + i
                torch.manual_seed(seed)
                torch.cuda.manual_seed(seed)
                loss = model(x, y).sum()
                losses.append(loss)
                loss.backward()
                opt.step()
                grads.append(x.grad)
                opt.zero_grad()
            assert torch.allclose(losses[0], losses[1])
            assert torch.allclose(grads[0], grads[1])
            losses.clear()
            grads.clear()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_second_layer", [True, False])
    @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
    def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy):
        # When use_second_layer=True, b is involved in forward computation but does
        # not receive grad in backward. Otherwise, b is not involved in forward
        # computation.

        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(10, 10)
                self.b = nn.Linear(10, 10)

            def forward(self, x, y):
                out1 = self.a(x)
                if use_second_layer:
                    out2 = self.b(y)
                    return out1, out2
                else:
                    return out1

        fsdp = FSDP(
            MyModel().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
        )
        x = torch.randn(10, 10, device="cuda")
        y = torch.randn(10, 10, device="cuda")
        for i in range(4):
            if use_second_layer:
                a, b = fsdp(x, y)
            else:
                a = fsdp(x, y)
            loss = a.sum()
            loss.backward()

            # self.a receives grad, self.b does not
            a_grad = fsdp.module.a._handle.flat_param.grad
            b_grad = fsdp.module.b._handle.flat_param.grad
            self.assertIsNotNone(a_grad)
            self.assertIsNone(b_grad)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_not_all_outputs_used_in_loss(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_fsdp_not_all_outputs_used_in_loss,
        )

    def _test_fsdp_not_all_outputs_used_in_loss(
        self, sharding_strategy: ShardingStrategy
    ):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(4, 4)
                self.lin2 = nn.Linear(4, 4)

            def forward(self, x):
                a = self.lin1(x)
                b = self.lin2(x)
                return (a, b)

        def _check_resharded(fsdp_module):
            handle = fsdp_module._handle
            if not handle:
                return
            param = handle.flat_param
            if handle.uses_sharded_strategy:
                full_param = param._full_param_padded
                self.assertEqual(full_param.storage().size(), 0)

            self.assertEqual(param.data_ptr(), param._local_shard.data_ptr())

        def _check_equal(local, fsdp):
            with FSDP.summon_full_params(fsdp):
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    torch.testing.assert_close(p1, p2)

        fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
        m = MyModule().cuda()
        m_local = deepcopy(m)
        local_m = m_local
        prev_params = [p.clone() for p in m_local.parameters()]

        m.lin1 = fsdp_ctor(m.lin1)
        m = fsdp_ctor(m)
        _check_equal(m_local, m)

        opt = torch.optim.SGD(m.parameters(), lr=1e-3)
        opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)

        for i in range(6):
            t = torch.ones(4, device="cuda")
            a, b = m(t)
            local_a, local_b = local_m(t)
            if i < 2:
                # use both params in loss computation. Later,
                # b will go unused and we check grads are the
                # same as local training.
                loss = (a @ b).sum()
                loss_local = (local_a @ local_b).sum()
            else:
                loss = a.sum()
                loss_local = local_a.sum()

            loss.backward()
            loss_local.backward()
            _check_resharded(m)
            opt.step()
            opt_local.step()
            _check_equal(m_local, m)
            # Ensure at least some change from previous params, otherwise
            # above check would be vacuously true.
            self.assertTrue(
                any(
                    not torch.equal(p1, p2)
                    for p1, p2 in zip(prev_params, m_local.parameters())
                )
            )
            prev_params = [p.clone() for p in local_m.parameters()]
            opt.zero_grad()
            opt_local.zero_grad()

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_optim_overlap_no_use_orig_params_error(self):
        fsdp_overlap = FSDP(
            MyModel().cuda(),
            auto_wrap_policy=always_wrap_policy,
            use_orig_params=False,
        )
        optim_cls = torch.optim.SGD
        optim_kwargs = {"lr": 0.03}
        _apply_optimizer_in_backward(
            optimizer_class=optim_cls,
            params=fsdp_overlap.parameters(),
            optimizer_kwargs=optim_kwargs,
            register_hook=False,
        )

        inp = torch.randn(10, 10, device="cuda")
        with self.assertRaisesRegex(
            RuntimeError, "only supported with use_orig_params=True"
        ):
            fsdp_overlap(inp, inp)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_optimizer_overlap(self):
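        """
        Tests overlapping the optimizer step with backward by applying the
        optimizer in backward via ``_apply_optimizer_in_backward`` and
        comparing against an FSDP model that steps a regular optimizer, with
        and without parameter CPU offloading.
        """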
        torch.manual_seed(0)
        for cpu_offload in [True, False]:
            offload = CPUOffload(offload_params=cpu_offload)
            model = MyModel().cuda()
            model_overlap = deepcopy(model)
            fsdp = FSDP(
                model.cuda(),
                auto_wrap_policy=always_wrap_policy,
                use_orig_params=True,
                cpu_offload=offload,
            )
            fsdp_overlap = FSDP(
                model_overlap.cuda(),
                auto_wrap_policy=always_wrap_policy,
                use_orig_params=True,
                cpu_offload=offload,
            )
            optim_cls = torch.optim.SGD
            optim_kwargs = {"lr": 0.03}
            _apply_optimizer_in_backward(
                optimizer_class=optim_cls,
                params=fsdp_overlap.parameters(),
                optimizer_kwargs=optim_kwargs,
                register_hook=False,
            )
            for p in fsdp_overlap.parameters():
                assert hasattr(p, "_in_backward_optimizers")
            optim = optim_cls(fsdp.parameters(), **optim_kwargs)

            # Verify params initially equal
            for p1, p2 in zip(fsdp.parameters(), fsdp_overlap.parameters()):
                self.assertEqual(p1, p2)

            with FSDP.summon_full_params(fsdp_overlap):
                fsdp_overlap_prev_params = [
                    (n, p.clone()) for n, p in fsdp_overlap.named_parameters()
                ]

            for i in range(6):
                inp = torch.randn(2, 2, device="cuda")
                with torch.no_grad():
                    inp_clone = inp.clone()
                fsdp(inp, inp).sum().backward()
                fsdp_overlap(inp_clone, inp_clone).sum().backward()

                optim.step()
                optim.zero_grad()

                # Overlapped optimizer FSDP module should have sharded_grad as None.
                for fsdp_unit in FSDP.fsdp_modules(fsdp_overlap):
                    handle = fsdp_unit._handle
                    if handle:
                        handle_grad = handle.sharded_grad
                        self.assertEqual(
                            None,
                            handle_grad,
                            "Overlapped FSDP sharded_grad is not None!",
                        )

                # Note: FSDP without optimizer overlap won't set sharded_grad to None until the next
                # pre-forward since it needs to run FSDP specific logic that picks up that set_to_none=True
                # has been called (or that the gradients have been otherwise set to None)

                # Verify that parameters differ from the previous iteration
                with FSDP.summon_full_params(fsdp_overlap, with_grads=True):
                    for (n, p), (n_prev, p_prev) in zip(
                        fsdp_overlap.named_parameters(), fsdp_overlap_prev_params
                    ):
                        self.assertNotEqual(
                            p,
                            p_prev,
                            f"{n_prev} Params at iter {i} same as previous iter!",
                        )

                # Verify that the overlapped and non-overlapped models match
                with FSDP.summon_full_params(fsdp_overlap):
                    with FSDP.summon_full_params(fsdp):
                        for (n_overlap, p_overlap), (n, p) in zip(
                            fsdp_overlap.named_parameters(), fsdp.named_parameters()
                        ):
                            self.assertEqual(n_overlap, n)
                            self.assertEqual(
                                p,
                                p_overlap,
                                f"Rank {self.rank}: Params not equal at iteration {i}: {n_overlap} - {p} vs {p_overlap}",
                            )
                            self.assertEqual(
                                None, p.grad, f"Expected param {n} grad to be None"
                            )
                            self.assertEqual(
                                None,
                                p_overlap.grad,
                                f"Expected param {n_overlap} grad to be None",
                            )

                    fsdp_overlap_prev_params = [
                        (n, p.clone()) for n, p in fsdp_overlap.named_parameters()
                    ]

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_training(self):
        """Tests FSDP training on CPU."""
        gloo_pg = dist.new_group(backend="gloo")
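        # A Gloo process group is used since NCCL does not support CPU tensors.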
        for ss in [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ]:
            torch.manual_seed(42)
            model = MyModel()
            ref_model = DDP(deepcopy(model), process_group=gloo_pg)
            model = FSDP(
                model,
                auto_wrap_policy=always_wrap_policy,
                process_group=gloo_pg,
                device_id=torch.device("cpu"),
            )
            ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
            optim = torch.optim.Adam(model.parameters(), lr=1e-2)
            torch.manual_seed(42 + self.rank)
            inp = torch.randn(2, 2)
            for _ in range(10):
                losses = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    loss = _model(inp, inp).sum()
                    losses.append(loss)
                    loss.backward()
                    _optim.step()
                    _optim.zero_grad()
                self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_init_stays_on_cpu(self):
        # Move me to an MT test once the warning logging and backward collective
        # issues are resolved.
        """Tests that passing a CPU module to FSDP keeps the wrapped module on
        CPU after FSDP initialization (albeit after logging a warning) and that
        FSDP moves CPU input to GPU before the forward."""
        torch.cuda.set_device(self.rank)
        regex = "passed-in `module` is on CPU"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_NEVER,
            )
            fsdp_model = FSDP(nested_wrapped_module, self.process_group)
        devices = {p.device for p in fsdp_model.parameters()}
        self.assertEqual(1, len(devices))
        self.assertEqual(torch.device("cpu"), devices.pop())
        fsdp_model = fsdp_model.cuda()
        # Ensure fwd + backward can be performed after moving to CUDA.
        # CPU input also tests that input is correctly moved to appropriate
        # CUDA device.
        inp = fsdp_model.module.get_input(device=torch.device("cpu"))
        fsdp_model(*inp).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_cpu_init_with_sync_module_states(self):
        """
        Tests that passing ``sync_module_states=True`` raises an error for
        a CPU module since the synchronization requires GPU communication,
        while additionally passing ``device_id`` does not raise an error, even
        when the model has CPU buffers.
        """
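        # Supported pattern (sketch): for a CPU module, `sync_module_states=True`
        # must be combined with `device_id`, e.g.
        #   FSDP(module, device_id=torch.cuda.current_device(), sync_module_states=True)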

        def init_nested_wrapped_module():
            return NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_NEVER,
            )

        with self.assertRaisesRegex(
            ValueError,
            "The module has CPU parameters or buffers when `sync_module_states=True`",
        ):
            FSDP(
                init_nested_wrapped_module(),
                self.process_group,
                sync_module_states=True,
            )

        # Check that `device_id` with `sync_module_states=True` works
        nested_wrapped_module = init_nested_wrapped_module()
        nested_wrapped_module.buf = nn.Buffer(
            torch.ones((2, 2), device="cpu") * self.rank
        )
        nested_wrapped_module.module[0].buf = nn.Buffer(
            torch.ones((3, 2), device="cpu") * self.rank
        )
        nested_wrapped_module = FSDP(
            nested_wrapped_module,
            self.process_group,
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
        )
        # Each rank's buffers should be 0s since rank 0 is the source, and they
        # should be on GPU since we specified `device_id`
        self.assertEqual(
            nested_wrapped_module.buf.device,
            torch.device("cuda", torch.cuda.current_device()),
        )
        self.assertEqual(nested_wrapped_module.buf, torch.zeros((2, 2)))
        self.assertEqual(
            nested_wrapped_module.module.module[0].buf.device,
            torch.device("cuda", torch.cuda.current_device()),
        )
        self.assertEqual(
            nested_wrapped_module.module.module[0].buf, torch.zeros((3, 2))
        )


class TestFSDPMiscMultiThread(FSDPTestMultiThread):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_namedtuple(self):
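        """Tests that FSDP registers backward hooks on the tensors of a
        ``namedtuple`` forward output."""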
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(100, 100)

            def forward(self, x):
                return x

        m = MyModule().cuda()
        m = FSDP(m)
        t = torch.ones(1, device="cuda", requires_grad=True)

        MyOutputType = namedtuple(
            "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
        )

        inp = MyOutputType()
        out = m(inp)
        # Ensure hooks are registered
        for x in out:
            self.assertNotEqual([], list(x._backward_hooks.values()))

        # TODO: we should check backward() and param is resharded
        # as well, but this is blocked by
        # https://github.com/pytorch/pytorch/issues/83107 and
        # https://github.com/pytorch/pytorch/issues/83129

    @skip_if_lt_x_gpu(2)
    def test_device_id_auto_wrap(self):
        """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
        nested FSDP instances."""
        self.run_subtests(
            {"use_callable": [False, True]},
            self._test_device_id_auto_wrap,
        )

    def _test_device_id_auto_wrap(self, use_callable: bool):
        module_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
        if use_callable:
            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=module_classes,
            )
        else:
            auto_wrap_policy = ModuleWrapPolicy(module_classes)
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "device_id": torch.cuda.current_device(),
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            self.assertEqual(
                fsdp_module.compute_device,
                torch.device("cuda", torch.cuda.current_device()),
            )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_cpu_offload(self):
        """
        Tests FSDP when specifying both ``device_id`` and parameter CPU
        offloading.
        """
        self.run_subtests(
            {"use_orig_params": [False, True]},
            self._test_fsdp_device_id_cpu_offload,
        )

    def _test_fsdp_device_id_cpu_offload(self, use_orig_params: bool):
        class MyModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = nn.Sequential(
                    nn.Linear(10, 10),
                    nn.Linear(10, 10),
                )
                self.lin = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin(self.seq(x))

        model = MyModel()
        # Choose a wrapping policy such that there are (1) nested FSDP
        # instances and (2) the parent FSDP instance has managed parameters
        auto_wrap_policy = ModuleWrapPolicy({nn.Sequential})
        fsdp_model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            cpu_offload=CPUOffload(offload_params=True),
            device_id=torch.cuda.current_device(),
            use_orig_params=use_orig_params,
        )
        cpu_device = torch.device("cpu")
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.device, cpu_device)

    @skip_if_lt_x_gpu(2)
    def test_module_device_mismatches_device_id(self):
        """Tests that specifying a ``device_id`` argument to FSDP for a GPU
        module that does not match the GPU device ID raises an error."""
        # TODO: override FSDP MT Thread _run to set this instead of here for
        # every test.
        torch.cuda.set_device(self.rank)
        context = (
            self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
            if self.rank != 0
            else nullcontext()
        )
        with context:
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                # Move wrapped modules to CUDA before wrapping with FSDP
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                # Should raise error since rank 1 is given `device_id=0` when
                # the model is on cuda:1
                fsdp_kwargs={"device_id": 0},
            )

    @skip_if_lt_x_gpu(2)
    def test_cpu_gpu_module(self):
        """Tests that a module with both CPU and GPU submodules is supported
        when ``device_id`` is passed in and raises an error when it is not.
        """
        torch.cuda.set_device(self.rank)

        class CPUGPUModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(1, 1).cuda()
                self.b = nn.Linear(1, 1)

        cpu_gpu = CPUGPUModule()
        fsdp = FSDP(cpu_gpu, device_id=torch.cuda.current_device())
        for param in fsdp.parameters():
            self.assertEqual(param.device, torch.device(torch.cuda.current_device()))

        # without device_id, we hit an error
        with self.assertRaisesRegex(RuntimeError, "please pass in device_id"):
            FSDP(CPUGPUModule())

    @skip_if_lt_x_gpu(2)
    def test_fsdp_ignored_module_meta(self):
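        """Tests that parameters of ``ignored_modules`` are not materialized or
        moved by FSDP and remain on the meta device."""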
        torch.cuda.set_device(self.rank)

        class CPUGPUModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(1, 1)
                self.b = nn.Linear(1, 1)

        with torch.device("meta"):
            m = CPUGPUModule()
        m = FSDP(m, device_id=self.rank, ignored_modules=[m.a], use_orig_params=True)
        meta_device = torch.device("meta")
        self.assertEqual(meta_device, next(m.a.parameters()).device)

        # Test with param_init_fn
        with torch.device("meta"):
            m = CPUGPUModule()
        m = FSDP(
            m,
            device_id=torch.cuda.current_device(),
            ignored_modules=[m.a],
            use_orig_params=True,
            param_init_fn=lambda m: m.to_empty(
                device=torch.cuda.current_device(), recurse=False
            ),
        )
        self.assertEqual(meta_device, next(m.a.parameters()).device)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_no_move_ignored_params_and_bufs(self):
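        """Tests that ``device_id`` does not move the parameters and buffers of
        ignored modules to GPU; they should remain on CPU."""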
        class CPUGPUModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(1, 1)
                self.b = nn.Linear(1, 1)
                self.a.buf = torch.nn.Buffer(torch.ones(1))

        m = CPUGPUModule()
        m = FSDP(m, device_id=self.rank, ignored_modules=[m.a], use_orig_params=True)
        ignored_params = m.a.parameters()
        ignored_bufs = m.a.buffers()
        for t in chain(ignored_params, ignored_bufs):
            self.assertEqual(torch.device("cpu"), t.device)

    @skip_if_lt_x_gpu(2)
    def test_multigpu_module(self):
        """
        Tests that wrapping a module whose submodules are on multiple GPUs with
        FSDP raises an error.
        """

        class MultiGPUModule(nn.Module):
            def __init__(self, rank):
                super().__init__()
                self.rank = rank
                self.a = nn.Linear(1, 1).cuda(self.rank)
                self.b = nn.Linear(1, 1).cuda((self.rank + 1) % dist.get_world_size())

        with self.assertRaisesRegex(
            RuntimeError, "FSDP only supports single device modules"
        ):
            FSDP(MultiGPUModule(self.rank))

    @skip_if_lt_x_gpu(2)
    def test_no_params(self):
        """
        Tests that ``device_id`` and CPU init work when the module has no
        parameters (they are effectively no-ops, but this ensures that FSDP
        does not assume the module has parameters during init).
        """
        # TODO: override FSDP MT Thread _run to set this instead of here for
        # every test.
        torch.cuda.set_device(self.rank)
        # Test CPU
        no_params = nn.ReLU()
        module = FSDP(no_params)
        # Test CUDA
        no_params = nn.ReLU().cuda()
        module = FSDP(no_params)
        # Test CPU + device_id
        no_params = nn.ReLU()
        module = FSDP(no_params, device_id=torch.cuda.current_device())
        # For modules with no params, wrong device_id will raise error about
        # inconsistency between compute_device and device_id, since compute_device
        # is computed as torch.cuda.current_device when there are no params.
        no_params = nn.ReLU().cuda()
        context = (
            (
                self.assertRaisesRegex(
                    ValueError, f"Inconsistent.*cuda:{self.rank} vs cuda:0"
                )
            )
            if self.rank != 0
            else nullcontext()
        )
        with context:
            FSDP(no_params, device_id=0)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_same_model_across_ranks(self):
        """
        Tests that FSDP with ``sync_module_states=True`` broadcasts the model
        from rank 0 so that all ranks start with the same values.
        """

        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.buffer = nn.Buffer(torch.ones(1) * rank)

        m = MyModel(self.rank).cuda()
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )

        # sync_module_states also works with CPU module with device_id passed in
        m = MyModel(self.rank)
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )

    @skip_if_lt_x_gpu(2)
    def test_homogeneous_attributes(self):
        """
        Tests that passing heterogeneous values for attributes designated as
        homogeneous raises an error.
        """
        # Manually construct this list but verify against the global list of
        # homogeneous attribute names
        all_attr_name_and_values = [
            ("_use_orig_params", False, True),
            ("limit_all_gathers", False, True),
            ("_use_full_prec_in_eval", False, True),
        ]
        self.assertEqual(
            [
                attr_name_and_values[0]
                for attr_name_and_values in all_attr_name_and_values
            ],
            HOMOGENEOUS_ATTR_NAMES,
        )

        self.run_subtests(
            {"attr_name_and_values": all_attr_name_and_values},
            self._test_homogeneous_attributes,
        )

    def _test_homogeneous_attributes(self, attr_name_and_values: Tuple[str, Any, Any]):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )
        attr_name = attr_name_and_values[0]

        if "_use_full_prec_in_eval" == attr_name:
            model.module[1] = FSDP(model.module[1])
            os.environ["FSDP_USE_FULL_PREC_IN_EVAL"] = "1"
            fsdp_model = FSDP(model)
        else:
            fsdp_kwargs_inner = {attr_name.lstrip("_"): attr_name_and_values[1]}
            fsdp_kwargs_outer = {attr_name.lstrip("_"): attr_name_and_values[2]}
            model.module[1] = FSDP(model.module[1], **fsdp_kwargs_inner)
            fsdp_model = FSDP(model, **fsdp_kwargs_outer)

        # Run a forward to trigger lazy initialization and the error
        with self.assertRaisesRegex(
            ValueError, f"Expects one homogeneous value for {attr_name}"
        ):
            inp = fsdp_model.module.get_input(torch.device("cuda"))
            fsdp_model(*inp)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_unsupported_module_cls(self):
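        """Tests that wrapping a container that does not implement ``forward``
        (e.g. ``nn.ModuleList`` or ``nn.ModuleDict``) as the FSDP root emits a
        warning."""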
        regex = r"FSDP will not all-gather parameters for containers that do not implement forward"
        model = nn.ModuleList([MLP(8, torch.device("cpu")) for _ in range(3)])
        with self.assertWarnsRegex(UserWarning, regex):
            FSDP(model, device_id="cuda")
        model = nn.ModuleDict(
            {"1": MLP(8, torch.device("cpu")), "2": MLP(8, torch.device("cpu"))}
        )
        with self.assertWarnsRegex(UserWarning, regex):
            FSDP(model)


class TestFSDPMiscWorldSize1(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 1

    @skip_if_lt_x_gpu(1)
    def test_world_size_1_sharding_strategy_warning(self):
        """
        Tests that FSDP issues a warning when it switches to using ``NO_SHARD``
        when the world size is 1.
        """
        warning_prefix = "FSDP is switching to use `NO_SHARD` instead of"
        # If the user already passes `NO_SHARD`, then there should not be a
        # warning
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")  # trigger all warnings
            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
            for warning in w:
                self.assertTrue(
                    warning.category != UserWarning
                    or not str(warning.message).startswith(warning_prefix)
                )

        # Check that a warning is issued
        warning_suffix = " since the world size is 1."
        # - Pass `FULL_SHARD` or `None`
        expected_regex_full_shard = (
            warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
        )
        with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
            FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
        with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
            FSDP(nn.Linear(3, 3).cuda())
        # - Pass `SHARD_GRAD_OP`
        expected_regex_shard_grad_op = (
            warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
        )
        with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
            FSDP(
                nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
            )

    @skip_if_lt_x_gpu(1)
    def test_training_device_mismatch_errors(self):
        """
        Tests that an informative error is raised when training starts if the
        FSDP parameters are not on the expected device. This applies both with
        and without parameter CPU offloading.
        """
        # Incorrectly not moving from CPU -> GPU
        model = torch.nn.Linear(10, 10)
        fsdp_model = FSDP(model)
        inp = torch.randn((2, 10))
        with self.assertRaisesRegex(
            RuntimeError,
            "An FSDP-managed module unexpectedly has parameters on cpu. Make "
            "sure to move the module to cuda:0 before training.",
        ):
            fsdp_model(inp)

        # Incorrectly moving from CPU -> GPU
        model = torch.nn.Linear(10, 10)
        fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
        fsdp_model.to(torch.device("cuda"))
        inp = torch.randn((2, 10))
        with self.assertRaisesRegex(
            RuntimeError,
            "An FSDP-managed module with parameter CPU offloading enabled has "
            "parameters on cuda:0. Make sure to not move the module from CPU "
            "when offloading parameters.",
        ):
            fsdp_model(inp)

    @skip_if_lt_x_gpu(2)
    def test_unsafe_setattr(self):
        """
        Tests that the environment variable for using unsafe ``setattr`` gates
        the behavior as expected.
        """
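        # Gating being tested: setting the environment variable named by
        # `_FSDP_USE_UNSAFE_SETATTR` to "1" makes FSDP bypass
        # `nn.Module.__setattr__` when setting parameter views, while "0" or
        # leaving it unset goes through `__setattr__`.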
        self.run_subtests(
            {"use_orig_params": [False, True]},
            self._test_unsafe_setattr,
        )

    def _test_unsafe_setattr(self, use_orig_params: bool):
        called_setattr_override = False

        class SetattrLinear(nn.Module):
            def __init__(self, in_dim: int, out_dim: int, device: torch.device) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.randn((in_dim, out_dim), device=device)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x @ self.weight

            def __setattr__(self, name: str, value: Any) -> None:
                nonlocal called_setattr_override
                called_setattr_override = True
                return super().__setattr__(name, value)

        # Construct FSDP module without changing any environment variables and
        # run forward, which triggers both unsharded and sharded view setting
        module = SetattrLinear(5, 5, torch.device("cuda"))
        fsdp_module = FSDP(module, use_orig_params=use_orig_params)
        inp = torch.randn((8, 5), device=torch.device("cuda"))
        called_setattr_override = False
        fsdp_module(inp)
        self.assertTrue(called_setattr_override)

        # Repeat with unsafe setattr explicitly enabled
        os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
        module = SetattrLinear(5, 5, torch.device("cuda"))
        fsdp_module = FSDP(module, use_orig_params=use_orig_params)
        called_setattr_override = False
        fsdp_module(inp)
        self.assertFalse(called_setattr_override)

        # Repeat with unsafe setattr explicitly disabled
        os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
        module = SetattrLinear(5, 5, torch.device("cuda"))
        fsdp_module = FSDP(module, use_orig_params=use_orig_params)
        called_setattr_override = False
        fsdp_module(inp)
        self.assertTrue(called_setattr_override)


instantiate_parametrized_tests(TestFSDPMiscMultiThread)
instantiate_parametrized_tests(TestFSDPMiscMultiProcess)

if __name__ == "__main__":
    run_tests()