# Owner(s): ["oncall: distributed"]

import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import DTensor
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests


class Net(nn.Module):
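    """Simple three-layer MLP used as the toy model throughout these tests."""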
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.fc3 = nn.Linear(2, 2)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))


class ReplicateStateDictTest(MultiProcessTestCase):
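    """Checks that replicate() leaves module state_dict keys and values unchanged."""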
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _init_pg(self):
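        # Rendezvous over a FileStore backed by this test's temporary file.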
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def _check_state_dict_parity(self, sd_1, sd_2):
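        # Both state_dicts must expose identical keys (in order) and equal values.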
        for k1, k2 in zip(sd_1.keys(), sd_2.keys()):
            self.assertEqual(k1, k2)

        for v1, v2 in zip(sd_1.values(), sd_2.values()):
            self.assertEqual(v1, v2)

    def test_replicate_single_module_save_load(self):
        """
        Tests that the state_dict of a replicate()-wrapped module matches
        the local module's state_dict.
        """
        self._init_pg()
        model = Net()
        replicate_model = replicate(deepcopy(model))
        local_sd = model.state_dict()
        ddp_sd = replicate_model.state_dict()
        self._check_state_dict_parity(local_sd, ddp_sd)

    def test_replicate_non_root_multiple_save_load(self):
        """
        Tests that replicate() applied to multiple submodules matches the
        local module's state_dict.
        """
        self._init_pg()
        model = Net()
        replicate_model = deepcopy(model)
        replicate(replicate_model.fc1)
        replicate(replicate_model.fc2)
        replicate(replicate_model.fc3)

        local_sd = model.state_dict()
        ddp_sd = replicate_model.state_dict()
        self._check_state_dict_parity(local_sd, ddp_sd)


class ReplicateTest(MultiProcessTestCase):
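    """Compares training with replicate()-wrapped modules against local training."""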
    @property
    def world_size(self) -> int:
        return 2

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _init_pg(self):
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )

    def _compare_module(self, mod, replicate_mod):
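        # Step a local module on the full batch and the replicated module on
        # this rank's shard; after gradient averaging, their parameters should
        # stay equal within tolerance.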
        local_batch_size = 1
        global_batch_size = self.world_size * local_batch_size
        input = torch.randn(global_batch_size, 2)
        target = torch.randn(global_batch_size, 2)

        def step_model(model, input, target):
            model.train()
            output = model(input)
            loss = F.mse_loss(output, target.to(output.device))
            loss.backward()
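            # Manually apply a plain SGD update (lr = 1.0) and clear the grads.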
            for param in model.parameters():
                with torch.no_grad():
                    param -= param.grad
                param.grad = None

        for iteration in range(2):
            step_model(mod, input, target)
            step_model(
                replicate_mod,
                input[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
                target[
                    self.rank * local_batch_size : (self.rank + 1) * local_batch_size
                ],
            )

            self.assertEqual(
                len(list(mod.parameters())),
                len(list(replicate_mod.parameters())),
            )
            for i, j in zip(mod.parameters(), replicate_mod.parameters()):
                self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)

            # Shuffle the input so that DDP input is different
            torch.manual_seed(iteration)
            input = input[torch.randperm(global_batch_size)]

    def test_replicate_single_module(self):
        self._init_pg()
        model = Net()
        replicate_model = replicate(deepcopy(model))
        self._compare_module(model, replicate_model)

    @skip_if_lt_x_gpu(2)
    def test_replicate_move_args_kwargs_to_device(self):
        class MyNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Linear(2, 2)

            def forward(self, inp, *, kwarg=None):
                if kwarg is not None:
                    inp = inp @ kwarg
                return self.a(inp)

        self._init_pg()
        torch.cuda.set_device(self.rank)
        model = MyNet().cuda()
        replicate(model, device_id=torch.cuda.current_device())
        # CPU inputs ensure that replicate moves both args and kwargs to device.
        a, b = torch.randn(2, 2), torch.randn(2, 2)
        model(a, kwarg=b).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_replicate_ignore_module(self):
        self._init_pg()
        torch.cuda.set_device(self.rank)
        # Seed ensures different inputs and thus different local grads across ranks.
        torch.manual_seed(self.rank)
        torch.cuda.manual_seed(self.rank)
        model = Net().cuda()
        replicate(model, ignored_modules=[model.fc1])
        # The input already lives on GPU and is scaled per rank so that
        # unsynchronized gradients differ across ranks.
        inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
        out = model(inp) * 10
        out.sum().backward()
        # fc1 grads should not be synchronized; fc2 and fc3 grads should be.
        fc1_grad = model.fc1.weight.grad
        tensor_list = [torch.zeros_like(fc1_grad) for _ in range(dist.get_world_size())]
        dist.all_gather(tensor_list, fc1_grad)
        grad, rest = tensor_list[0], tensor_list[1:]
        for g in rest:
            self.assertNotEqual(grad, g)

        for dp_grad in [model.fc2.weight.grad, model.fc3.weight.grad]:
            tensor_list = [
                torch.zeros_like(dp_grad) for _ in range(dist.get_world_size())
            ]
            dist.all_gather(tensor_list, dp_grad)
            grad, rest = tensor_list[0], tensor_list[1:]
            for g in rest:
                self.assertEqual(grad, g)

    def test_replicate_multi_module(self):
        self._init_pg()
        model = Net()
        replicate_model = deepcopy(model)
        replicate(replicate_model.fc1)
        replicate(replicate_model.fc2)
        replicate(replicate_model.fc3)
        self._compare_module(model, replicate_model)

    def test_replicate_with_kwargs(self):
        self._init_pg()
        model = Net()
        replicate_model = replicate(
            deepcopy(model), bucket_cap_mb=1, gradient_as_bucket_view=True
        )
        self._compare_module(model, replicate_model)

    @skip_if_lt_x_gpu(2)
    def test_replicate_device_id(self):
        self._init_pg()
        model = Net()
        model_cuda = deepcopy(model).cuda()
        model_cuda2 = deepcopy(model_cuda)
        replicate(model, device_id=torch.device("cpu"))
        # The DDP instance is attached in the first pre-forward.
        model(torch.randn(2, 2))
        replicate_ddp_weakref = replicate.state(model)._ddp_weakref()
        # device_ids should be None for CPU training.
        self.assertEqual(None, replicate_ddp_weakref.device_ids)

        replicate(model_cuda, device_id=torch.device(torch.cuda.current_device()))
        # The DDP instance is attached in the first pre-forward.
        model_cuda(torch.randn(2, 2))
        replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
        self.assertEqual([0], replicate_ddp_weakref.device_ids)
        # Pass an int as device_id.
        replicate(model_cuda2, device_id=int(torch.cuda.current_device()))
        # The DDP instance is attached in the first pre-forward.
        model_cuda2(torch.randn(2, 2))
        replicate_ddp_weakref = replicate.state(model_cuda2)._ddp_weakref()
        self.assertEqual([0], replicate_ddp_weakref.device_ids)

    def test_replicate_wrong_device_id_type(self):
        self._init_pg()
        model = Net()
        with self.assertRaisesRegex(
            RuntimeError, "Expected device_id to be int or torch.device"
        ):
            replicate(model, device_id=[torch.device("cpu")])


class ReplicateFullyShardInit(ReplicateTest):
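    """Checks that replicate() composes with fully_shard-applied submodules."""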
    @skip_if_lt_x_gpu(2)
    def test_replicate_fully_shard_init(self):
        class ToyModel(nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.linears = nn.Sequential(
                    nn.Linear(dim, dim, bias=False),
                    nn.Linear(dim, dim, bias=False),
                    nn.Linear(dim, dim, bias=False),
                )
                self.proj = nn.Linear(dim, dim, bias=False)

            def forward(self, x: torch.Tensor):
                y = self.linears(x)
                y = self.proj(y)
                return y

        self._init_pg()
        torch.cuda.set_device(self.rank)
        dim = 3
        bz = 2
        model = ToyModel(dim).cuda()
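        # Shard each linear and then the container before applying replicate()
        # at the root.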
        for linear in model.linears:
            fully_shard(linear)
        fully_shard(model.linears)
        replicate(model, device_id=torch.cuda.current_device())
        for linear in model.linears:
            self.assertTrue(isinstance(linear.weight, DTensor))
        inp = torch.rand(bz, dim)
        # Trigger lazy init.
        model(inp).sum()
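        # Parameters should remain DTensors after lazy initialization.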
        for linear in model.linears:
            self.assertTrue(isinstance(linear.weight, DTensor))


if __name__ == "__main__":
    run_tests()