# Owner(s): ["oncall: distributed"]

import sys
from contextlib import nullcontext
from enum import auto, Enum
from typing import List, Optional
from unittest.mock import patch

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    MLP,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class PassType(Enum):
    __order__ = "FWD BWD"
    FWD = auto()
    BWD = auto()


class TestCommunication(FSDPTest):
    """Tests ``FullyShardedDataParallel``'s collective communication usage."""

    def _init_model(
        self,
        nested_model: bool,
        sharding_strategy: ShardingStrategy,
        device: torch.device,
    ):
        fsdp_kwargs = {"sharding_strategy": sharding_strategy}
        if nested_model:
            model = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_AFTER,
                fsdp_kwargs,
            )
            fsdp_model: FSDP = FSDP(
                model,
                self.process_group,
                **fsdp_kwargs,
            ).to(device)
        else:
            fsdp_model: FSDP = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs,
            )
        return fsdp_model

    def _run_iter(self, fsdp_model, batch, use_no_sync: bool):
        """Runs an iteration inside or outside the ``no_sync()`` context."""
        context = fsdp_model.no_sync() if use_no_sync else nullcontext()
        with context:
            output = fsdp_model(*batch)
            loss = fsdp_model.module.get_loss(batch, output)
            loss.backward()

    def _get_ref_num_reduce_scatters(
        self,
        num_fsdp: int,
        in_no_sync: bool,
    ) -> int:
        """Returns the reference number of reduce-scatters for an iteration,
        which is zero if the iteration runs inside the ``no_sync()`` context."""
        return num_fsdp if not in_no_sync else 0

    def _get_ref_num_all_gathers(
        self,
        num_fsdp: int,
        sharding_strategy: Optional[ShardingStrategy],
        is_first_iter: bool,
        is_last_iter_no_sync: bool,
    ) -> int:
        """Returns the reference number of all-gathers in an iteration, summing
        over the forward and backward passes."""
        return sum(
            self._get_ref_num_all_gathers_in_pass(
                num_fsdp,
                sharding_strategy,
                pass_type,
                is_first_iter,
                is_last_iter_no_sync,
            )
            for pass_type in PassType
        )

    def _get_ref_num_all_gathers_in_pass(
        self,
        num_fsdp: int,
        sharding_strategy: Optional[ShardingStrategy],
        pass_type: PassType,
        is_first_iter: bool,
        is_last_iter_no_sync: bool,
    ):
        """Returns the reference number of all-gathers for a given setting."""
        if sharding_strategy is None:
            sharding_strategy = ShardingStrategy.FULL_SHARD  # default
        # Forward pass:
        if (
            pass_type == PassType.FWD
            and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP
            and is_last_iter_no_sync
        ):
            # Modules do not free the full parameters in the last
            # iteration's backward pass if it was in `no_sync()`
            num_all_gathers = 0
        elif pass_type == PassType.FWD:
            # Otherwise, all modules all-gather the full parameters in the
            # forward pass
            num_all_gathers = num_fsdp
        # Backward pass:
        elif (
            pass_type == PassType.BWD
            and sharding_strategy == ShardingStrategy.FULL_SHARD
        ):
            # Root does not free the full parameters at the end of the
            # forward pass
            num_all_gathers = num_fsdp - 1
        elif (
            pass_type == PassType.BWD
            and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP
        ):
            # Modules do not free the full parameters at the end of the
            # forward pass
            num_all_gathers = 0
        else:
            assert 0, (
                f"Unsupported: add a branch for pass_type={pass_type} "
                f"is_first_iter={is_first_iter} "
                f"is_last_iter_no_sync={is_last_iter_no_sync} "
                f"sharding_strategy={sharding_strategy}"
            )
        if is_first_iter and pass_type == PassType.FWD:
            # With execution order validation, on the first iteration, we have
            # an additional two all-gathers before every actual all-gather in
            # the forward pass
            num_all_gathers *= 3
        return num_all_gathers

    def _print_ref_num_all_gathers_in_pass(
        self,
        num_fsdp: int,
        sharding_strategy: ShardingStrategy,
        pass_type: PassType,
        is_first_iter: bool,
        is_last_iter_no_sync: bool,
    ):
        """Helper method for printing the number of all-gathers for a specific
        setting. This may be helpful since the branching is complex."""
        if self.rank != 0:
            return  # only print on one rank
        num_all_gathers = self._get_ref_num_all_gathers_in_pass(
            num_fsdp,
            sharding_strategy,
            pass_type,
            is_first_iter,
            is_last_iter_no_sync,
        )
        print(
            f"Pass: {pass_type}\n"
            f"Is First Iteration: {is_first_iter}\n"
            f"Sharding Strategy: {sharding_strategy}\n"
            f"Last iteration in `no_sync()`: {is_last_iter_no_sync}\n"
            f"Number of all-gathers: {num_all_gathers}"
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("nested_model", [False, True])
    @parametrize("use_no_sync", [False, True])
    @parametrize("sharding_strategy", [ShardingStrategy.SHARD_GRAD_OP, None])
    def test_communication(
        self,
        nested_model: bool,
        use_no_sync: bool,
        sharding_strategy: Optional[ShardingStrategy],
    ):
        """
        Tests FSDP's communication cost in terms of calls to collective
        communication primitives (i.e. all-gather and reduce-scatter).

        Arguments:
            nested_model (bool): If ``True``, uses ``NestedWrappedModule``,
                which has nested FSDP instances; if ``False``, uses the default
                model, which does not have nested FSDP instances.
            use_no_sync (bool): If ``True``, runs some iterations inside the
                ``no_sync()`` context manager to accumulate gradients, followed
                by some iterations outside the context manager; if ``False``,
                only runs some iterations outside the context manager.
            sharding_strategy (Optional[ShardingStrategy]): Configures the
                FSDP algorithm.
        """
        # Enable execution order checking
        dist.set_debug_level(dist.DebugLevel.DETAIL)
        # Initialize the model and inputs
        device = torch.device("cuda")
        fsdp_model = self._init_model(nested_model, sharding_strategy, device)
        batch = fsdp_model.module.get_input(device)

        # Count the number of FSDP instances that manage parameters since the
        # number of collectives is a function of this number
        num_fsdp = sum(
            (isinstance(m, FSDP) and len(m.params) > 0) for m in fsdp_model.modules()
        )

        # If `use_no_sync=True`, we run `num_iters` iterations inside
        # `no_sync()` followed by `num_iters` iterations outside `no_sync()`,
        # and if `use_no_sync=False`, we only run `num_iters` iterations
        # outside `no_sync()`
        num_iters = 3
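        # Patch the collectives that FSDP calls so that we can count how many
        # all-gathers and reduce-scatters are issued per iteration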
        with patch(
            "torch.distributed.all_gather_into_tensor"
        ) as mock_all_gather, patch(
            "torch.distributed.reduce_scatter_tensor"
        ) as mock_reduce_scatter:

            def reset_mocks():
                mock_all_gather.reset_mock()
                mock_reduce_scatter.reset_mock()

            # Check the communication cost when using `no_sync()`
            if use_no_sync:
                for i in range(num_iters):
                    reset_mocks()
                    self._run_iter(fsdp_model, batch, use_no_sync=True)
                    num_all_gathers = mock_all_gather.call_count
                    num_reduce_scatters = mock_reduce_scatter.call_count
                    ref_num_all_gathers = self._get_ref_num_all_gathers(
                        num_fsdp,
                        sharding_strategy,
                        is_first_iter=i == 0,
                        is_last_iter_no_sync=i > 0,
                    )
                    ref_num_reduce_scatters = self._get_ref_num_reduce_scatters(
                        num_fsdp,
                        in_no_sync=True,
                    )
                    self.assertEqual(num_all_gathers, ref_num_all_gathers)
                    self.assertEqual(num_reduce_scatters, ref_num_reduce_scatters)
            # Check the normal communication cost (when not using `no_sync()`)
            for i in range(num_iters):
                reset_mocks()
                self._run_iter(fsdp_model, batch, use_no_sync=False)
                num_all_gathers = mock_all_gather.call_count
                num_reduce_scatters = mock_reduce_scatter.call_count
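                # If `use_no_sync=True`, then the preceding `no_sync()`
                # iterations already ran, so `i == 0` here is not the overall
                # first iteration, and its previous iteration was in
                # `no_sync()`, which affects the expected all-gather count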
                ref_num_all_gathers = self._get_ref_num_all_gathers(
                    num_fsdp,
                    sharding_strategy,
                    is_first_iter=not use_no_sync and i == 0,
                    is_last_iter_no_sync=use_no_sync and i == 0,
                )
                ref_num_reduce_scatters = self._get_ref_num_reduce_scatters(
                    num_fsdp,
                    in_no_sync=False,
                )
                self.assertEqual(num_all_gathers, ref_num_all_gathers)
                self.assertEqual(num_reduce_scatters, ref_num_reduce_scatters)


class TestExplicitUnshard(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_orig_params", [False, True])
    def test_unshard_async(self, use_orig_params: bool):
        class ReduceModule(nn.Module):
            def __init__(self, dim: int, group: dist.ProcessGroup):
                super().__init__()
                self.group = group
                self.weight = nn.Parameter(torch.randn(dim, dim))

            def forward(self, x: torch.Tensor):
                y = F.relu(x @ self.weight)
                # NOTE: This all-reduce is not differentiable and is included
                # to exercise the overlap.
                work = dist.all_reduce(y, group=self.group, async_op=True)
                return y, work

        class MLPs(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.mlp1 = MLP(dim)
                self.mlp2 = MLP(dim)
                self.mlp3 = MLP(dim)

            def forward(self, ys: List[torch.Tensor], works: List[dist.Work]):
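                # Wait on each all-reduce only right before its result is
                # needed so that the later all-reduces can overlap with the
                # earlier MLPs' computation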
                (y1, y2, y3), (work1, work2, work3) = ys, works
                work1.wait()
                z1 = self.mlp1(y1)
                work2.wait()
                z2 = self.mlp2(y2)
                work3.wait()
                z3 = self.mlp3(y3)
                return z1 + z2 + z3

        class ReduceModel(nn.Module):
            def __init__(self, dim: int, group: dist.ProcessGroup):
                super().__init__()
                self.reduce_module1 = ReduceModule(dim, group)
                self.reduce_module2 = ReduceModule(dim, group)
                self.reduce_module3 = ReduceModule(dim, group)
                self.mlps = MLPs(dim)

            def forward(self, x: torch.Tensor):
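                # Issue each MLP's unshard (all-gather) asynchronously right
                # after the corresponding reduce module runs so that the
                # all-gather overlaps with the in-flight all-reduce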
                y1, work1 = self.reduce_module1(x)
                if isinstance(self.mlps.mlp1, FSDP):
                    self.mlps.mlp1._unshard(async_op=True)
                y2, work2 = self.reduce_module2(x)
                if isinstance(self.mlps.mlp2, FSDP):
                    self.mlps.mlp2._unshard(async_op=True)
                y3, work3 = self.reduce_module3(x)
                if isinstance(self.mlps.mlp3, FSDP):
                    self.mlps.mlp3._unshard(async_op=True)
                return self.mlps([y1, y2, y3], [work1, work2, work3])

        group = self.process_group
        batch_size, dim = 2, 8
        torch.manual_seed(42)
        ref_model = DDP(ReduceModel(dim, group).cuda(), device_ids=[self.rank])
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

        torch.manual_seed(42)
        model = ReduceModel(dim, group)
        model.mlps = FSDP(
            model.mlps,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
            auto_wrap_policy=ModuleWrapPolicy((MLP,)),
            device_id=self.rank,
            use_orig_params=use_orig_params,
        )
        model.mlps.check_is_root()
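        # Tell DDP to ignore the FSDP-managed MLP parameters so that DDP and
        # FSDP each manage a disjoint set of parameters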
        mlp_params = set(model.mlps.parameters())
        mlp_param_names = {n for n, p in model.named_parameters() if p in mlp_params}
        DDP._set_params_and_buffers_to_ignore_for_model(model, mlp_param_names)
        model = DDP(model.cuda(), device_ids=[self.rank])
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((batch_size, dim), device="cuda")

        for _ in range(10):
            losses: List[torch.Tensor] = []
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                losses.append(_model(inp).sum())
                losses[-1].backward()
                _optim.step()
                _optim.zero_grad()
            self.assertEqual(losses[0], losses[1])
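            # Have FSDP's unshard stream wait on the current stream so that
            # the next iteration's explicit async unshards order correctly
            # after this iteration's work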
            model.module.mlps._wait_unshard_streams_on_current_stream()


instantiate_parametrized_tests(TestCommunication)
instantiate_parametrized_tests(TestExplicitUnshard)

if __name__ == "__main__":
    run_tests()