# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import functools
import threading
import unittest
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import DeviceMesh
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.two_tensor import TwoTensor


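# All-gather extension hooks for the `TwoTensor` subclass. The test fixture
# below monkey patches these onto `TwoTensor` so that FSDP all-gathers the two
# inner tensors (`a`, `b`) instead of the subclass object itself.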
def two_tensor_fsdp_pre_all_gather(
    self, mesh: DeviceMesh
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
    all_gather_inputs = (self.a, self.b)
    metadata = None
    return all_gather_inputs, metadata


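# The post-all-gather hook either copies into the preallocated `out` tensor or
# returns the reconstructed `TwoTensor` together with the temporary all-gather
# outputs that FSDP may free afterward.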
def two_tensor_fsdp_post_all_gather(
    self,
    all_gather_outputs: Tuple[torch.Tensor, ...],
    metadata: Any,
    param_dtype: torch.dtype,
    *,
    out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
    assert metadata is None, f"{metadata}"
    a, b = all_gather_outputs
    if out is not None:
        assert isinstance(out, TwoTensor), f"{type(out)}"
        if a.dtype == param_dtype:
            assert a.untyped_storage().data_ptr() == out.a.untyped_storage().data_ptr()
            assert b.untyped_storage().data_ptr() == out.b.untyped_storage().data_ptr()
        else:
            assert out.a.dtype == param_dtype, f"{out.a.dtype} {param_dtype}"
            assert out.b.dtype == param_dtype, f"{out.b.dtype} {param_dtype}"
            out.a.copy_(a)
            out.b.copy_(b)
        return
    tensors_to_free = (a, b)
    # If the cast is real, then the all-gather outputs will not alias the
    # returned `TwoTensor`'s `a` and `b`
    two_tensor = TwoTensor(a, b).to(param_dtype)
    return two_tensor, tensors_to_free


class TestFullyShardAllGatherExtensionsCommon:
    @property
    def world_size(self) -> int:
        return 2

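    # Temporarily install the all-gather hooks on `TwoTensor` for the duration
    # of a test; the barriers keep ranks in sync around the patch and cleanup.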
    @contextlib.contextmanager
    def _patch_two_tensor_fsdp_all_gather(self):
        lock = threading.Lock()
        TwoTensor.fsdp_pre_all_gather = two_tensor_fsdp_pre_all_gather
        TwoTensor.fsdp_post_all_gather = two_tensor_fsdp_post_all_gather
        dist.barrier()
        try:
            yield
        finally:
            dist.barrier()
            with lock:  # only one thread needs to delete
                if hasattr(TwoTensor, "fsdp_pre_all_gather"):
                    delattr(TwoTensor, "fsdp_pre_all_gather")
                if hasattr(TwoTensor, "fsdp_post_all_gather"):
                    delattr(TwoTensor, "fsdp_post_all_gather")

    def _init_two_tensor_mlp(self) -> nn.Module:
        # Disable bias because the reference model will end up with a bias
        # gradient that is a `TwoTensor`, whereas the FSDP model does not
        model = nn.Sequential(*[MLP(8, bias=False) for _ in range(3)])
        for mlp in model:
            mlp.in_proj.weight = nn.Parameter(
                TwoTensor(mlp.in_proj.weight, mlp.in_proj.weight.clone())
            )
            mlp.out_proj.weight = nn.Parameter(
                TwoTensor(mlp.out_proj.weight, mlp.out_proj.weight.clone())
            )
        return model


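# Multi-process tests: check numerical parity between the FSDP-sharded
# `TwoTensor` model and a replicated reference model over several train steps.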
class TestFullyShardAllGatherExtensionsMultiProcess(
    TestFullyShardAllGatherExtensionsCommon, FSDPTest
):
    @skip_if_lt_x_gpu(2)
    def test_all_gather_extensions_train_parity(self):
        with self._patch_two_tensor_fsdp_all_gather():
            self.run_subtests(
                {"reshard_after_forward": [True, False]},
                self._test_all_gather_extensions_train_parity,
            )

    def _test_all_gather_extensions_train_parity(self, reshard_after_forward: bool):
        torch.manual_seed(42)
        model = self._init_two_tensor_mlp()
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=True)
        fully_shard_fn = functools.partial(
            fully_shard, reshard_after_forward=reshard_after_forward
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        check_sharded_parity(self, ref_model, model)

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()
                if _model is ref_model:
                    for param_name, param in _model.named_parameters():
                        dist.all_reduce(param.grad)
                        param.grad.detach().div_(self.world_size)
            self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            check_sharded_parity(self, ref_model, model)


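# Multi-threaded tests: exercise the meta-device init flow with bf16 mixed
# precision and monkey patching the hooks onto plain `torch.Tensor` parameters.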
class TestFullyShardAllGatherExtensionsMultiThread(
    TestFullyShardAllGatherExtensionsCommon, FSDPTestMultiThread
):
    @property
    def device(self) -> torch.device:
        return torch.device("cuda:0")

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extensions_end_to_end(self):
        with self._patch_two_tensor_fsdp_all_gather():
            self.run_subtests(
                {"reshard_after_forward": [True, False]},
                self._test_all_gather_extensions_end_to_end,
            )

    def _test_all_gather_extensions_end_to_end(self, reshard_after_forward: bool):
        # Check that we can run the meta-device initialization flow
        with torch.device("meta"):
            model = self._init_two_tensor_mlp()
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("meta"))
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
            mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16),
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        model.to_empty(device=self.device)
        for param in model.parameters():
            nn.init.trunc_normal_(param)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)

        # Run a few iterations to check for errors
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for _ in range(3):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extensions_monkey_patch(self):
        # Define a pre/post-all-gather pair that quantizes to bf16 for the
        # all-gather and de-quantizes back to the parameter dtype
        def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
            return (self.to(torch.bfloat16),), None

        def fsdp_post_all_gather(
            self,
            all_gather_outputs: Tuple[torch.Tensor, ...],
            metadata: Any,
            param_dtype: torch.dtype,
            *,
            out: Optional[torch.Tensor] = None,
        ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
            (tensor,) = all_gather_outputs
            assert metadata is None, f"{metadata}"
            assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
            if out is not None:
                out.copy_(tensor)
                return
            return tensor.to(param_dtype), (tensor,)

        with torch.device("meta"):
            model = self._init_two_tensor_mlp()
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        model.to_empty(device=self.device)
        for param in model.parameters():
            nn.init.trunc_normal_(param)
        # Monkey patch the pre/post-all-gather functions *after* `to_empty()`
        # since the local tensor objects change from materialization
        self.assertGreater(sum("weight" in n for n, _ in model.named_parameters()), 0)
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                local_param = param.to_local()
                # Monkey patch on the `torch.Tensor` to show that the extension
                # can work even without a subclass
                local_param.fsdp_pre_all_gather = fsdp_pre_all_gather
                local_param.fsdp_post_all_gather = fsdp_post_all_gather
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)

        # Run a few iterations to check for errors
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for _ in range(3):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()


if __name__ == "__main__":
    run_tests()