# Owner(s): ["oncall: distributed"]

import copy
import logging
import math
import operator
import os
import random
import sys
import tempfile
from functools import reduce

import torch
import torch.distributed as c10d


if not c10d.is_available() or not c10d.is_ucc_available():
    print("c10d UCC not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import test_c10d_common
from test_c10d_common import (
    gpus_for_rank,
    ModuleForDdpCommHook,
    SparseGradientModule,
    Task,
)

import torch.distributed as dist
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_ucc,
    skip_if_lt_x_gpu,
    verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
    retry_on_connect_failures,
    run_tests,
    skip_but_pass_in_sandcastle,
    TestCase,
)


def simple_reduce_tests(rank, world_size):
    tests = [
        (
            c10d.ReduceOp.SUM,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size * (world_size + 1) / 2)]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(math.factorial(world_size))]),
        ),
        (
            c10d.ReduceOp.MIN,
            torch.tensor([rank + 1.0]),
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            torch.tensor([rank + 1.0]),
            torch.tensor([world_size]),
        ),
    ]
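    # Each rank contributes rank + 1, so SUM reduces to 1 + 2 + ... + world_size
    # = world_size * (world_size + 1) / 2, PRODUCT to world_size!, MIN to 1.0,
    # and MAX to world_size.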

    # Generate tests for BAND.
    # The bit that is set changes in every iteration to check
    # that the output changes accordingly.
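    # For example, with world_size = 2 and i = 0: rank 0 contributes 0 | 1 = 1 and
    # rank 1 contributes 1 | 1 = 1, so the AND-reduction yields 1 (= 1 << 0).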
    for i in range(4):
        vin = rank | (1 << i)
        vout = 1 << i
        tests.append(
            (
                c10d.ReduceOp.BAND,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for BOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-OR'ed.
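    # For example, with world_size = 2 and i = 2: rank 0 contributes 0 | 1 = 1 and
    # rank 1 contributes 2 | 3 = 3, so the OR-reduction yields 3 (= OR of range(4)).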
    for i in range(1, 5):
        vin = reduce(operator.or_, [rank * i + j for j in range(i)])
        vout = reduce(operator.or_, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for XOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-XOR'ed.
    for i in range(1, 5):
        vin = reduce(operator.xor, [rank * i + j for j in range(i)])
        vout = reduce(operator.xor, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BXOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    return tests


class RendezvousEnvTest(TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_logging_init(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())
        os.environ["RANK"] = "0"

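        # Capture the root logger's handlers before init; initializing the process
        # group should not add, remove, or replace any of them.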
        previous_handlers = logging.root.handlers

        c10d.init_process_group(backend="ucc", init_method="env://")

        current_handlers = logging.root.handlers
        self.assertEqual(len(previous_handlers), len(current_handlers))
        for current, previous in zip(current_handlers, previous_handlers):
            self.assertEqual(current, previous)

        c10d.destroy_process_group()


class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    @requires_ucc()
    @retry_on_connect_failures
    def test_default_store_timeout_ucc(self):
        self._test_default_store_timeout("ucc")


class ProcessGroupUCCTest(MultiProcessTestCase):
    def _create_process_group_ucc(self):
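        # Rendezvous through a file-backed store shared by all spawned ranks.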
        store = c10d.FileStore(self.file_name, self.world_size)
        return c10d.ProcessGroupUCC(store, self.rank, self.world_size)

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    def test_empty_tensors(self):
        pg = self._create_process_group_ucc()

        xs = [torch.FloatTensor([])]
        fut = pg.broadcast(xs).get_future()
        fut.wait()
        output = fut.value()
        self.assertEqual(0, output[0].numel())
        self.assertEqual(xs[0], output[0], exact_dtype=False)

    # TODO: add error check testing

    def _test_broadcast_basics(self, fn):
        pg = self._create_process_group_ucc()

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            fut = pg.broadcast(xs, opts).get_future()
            fut.wait()
            return fut.value()

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = fn(torch.tensor([self.rank]))
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0], exact_dtype=False)

            # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function
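        # With root=0, every rank should end up with rank 0's value, i.e. tensor([1.0]).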
        x = torch.tensor([self.rank + 1.0])
        fut = pg.broadcast(x, root=0).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(torch.tensor([1.0]), result[0])

    @requires_ucc()
    def test_broadcast_basics(self):
        self._test_broadcast_basics(lambda t: t.clone())

    # TODO: test_broadcast_basics_cuda times out locally

    def _test_allreduce_basics(self, fn):
        pg = self._create_process_group_ucc()

        # Single input tests
        tests = simple_reduce_tests(self.rank, self.world_size)
        for op, input, expected in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensor = fn(input)
            fut = pg.allreduce([tensor], opts).get_future()
            fut.wait()
            result = fut.value()
            self.assertEqual(expected, result[0], exact_dtype=False)

        # TODO: UCC currently does not support multi tensor input

        # Test overloaded convenience function (defaults to using sum)
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.allreduce(x).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(
            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
            result[0],
        )

    @requires_ucc()
    def test_allreduce_basics(self):
        self._test_allreduce_basics(lambda t: t.clone())

    # TODO: test_allreduce_basics_cuda times out locally

    def _test_allgather_basics(self, fn):
        pg = self._create_process_group_ucc()

        # TODO: run with N input tensors per rank; for now, UCC only supports a single tensor input, so N=1
        for n in [1]:
            input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
            output = [
                [fn(torch.tensor([-1])) for _ in range(n * self.world_size)]
                for _ in range(n)
            ]
            expected_output = [
                [fn(torch.tensor([i])) for i in range(n * self.world_size)]
                for _ in range(n)
            ]
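            # With n = 1 each rank contributes tensor([rank]), so the gathered
            # output is [tensor([0]), tensor([1]), ..., tensor([world_size - 1])].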
            fut = pg.allgather(output, input).get_future()
            fut.wait()
            result = fut.value()
            if n == 1:
                result = [result]
            self.assertEqual(expected_output, result)

    @requires_ucc()
    def test_allgather_basics(self):
        self._test_allgather_basics(lambda t: t.clone())

    def _test_reduce_basics(self, fn):
        pg = self._create_process_group_ucc()
        for op, input, output in simple_reduce_tests(self.rank, self.world_size):
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.reduceOp = op
                opts.rootRank = root
                tmp = fn(input)
                fut = pg.reduce([tmp], opts).get_future()
                fut.wait()
                result = fut.value()
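                # Only the root rank is guaranteed to hold the reduced result.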
                if root == self.rank:
                    self.assertEqual(output, result[0], exact_dtype=False)

    @requires_ucc()
    def test_reduce_basics(self):
        self._test_reduce_basics(lambda t: t.clone())

    # TODO: test_reduce_basics_cuda times out locally

    @requires_ucc()
    def test_send_recv_all_to_all(self):
        pg = self._create_process_group_ucc()

        # Preallocate tensors for input/output
        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]

        # Issue sends
        send_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            send_work.append(pg.send([inputs[i]], i, 0))

        # Issue recvs
        recv_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            recv_work.append(pg.recv([outputs[i]], i, 0))

        # Wait for sends to complete
        for work in send_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Wait for recvs to complete
        for work in recv_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Test that every output other than our own contains the respective rank
        for i in range(self.world_size):
            if i == self.rank:
                continue
            self.assertEqual(torch.tensor([i]), outputs[i])

    # TODO: test_barrier_implies_wait fails with numerical mismatch, will investigate later
    @skip_but_pass_in_sandcastle("fails with numerical mismatch, skip for now")
    @requires_ucc()
    def test_barrier_implies_wait(self):
        pg = self._create_process_group_ucc()

        # Kick off allreduce operations
        size = (100, 100)
        num = 16
        tensors = [torch.full(size, float(i)) for i in range(num)]
        for tensor in tensors:
            # Note: leak the returned work handle
            pg.allreduce(tensor)

        # Barrier should ensure all previous work has completed
        pg.barrier().get_future().wait()

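        # Every rank allreduce-summed identical tensors, so each entry should now
        # equal i * world_size.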
        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def _get_process_group(self):
        store = self._get_store()
        c10d.init_process_group(
            "ucc", store=store, rank=self.rank, world_size=self.world_size
        )
        return c10d.distributed_c10d._get_default_group()

    def _test_ucc_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        process_group = self._get_process_group()
        self._test_ddp_with_process_group(
            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
        )

    @requires_ucc()
    def test_ucc_backend_cpu_module(self):
        self._test_ucc_backend([torch.device("cpu")], None)

    @requires_ucc()
    def test_ucc_backend_cpu_module_grad_is_view(self):
        self._test_ucc_backend(
            [torch.device("cpu")], None, gradient_as_bucket_view=True
        )

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_integer_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, int_devices)

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ucc_backend_1gpu_module_device_ids_torch_device_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, devices)

    # TODO: test_ucc_backend_2gpu_module and test_ucc_backend_4gpu_module
    # require broadcast_coalesced which is not supported by ucc currently
    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(4)
    def test_ucc_backend_2gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    @skip_but_pass_in_sandcastle(
        "requires broadcast coalesced, which is not supported by ucc currently"
    )
    @requires_ucc()
    @skip_if_lt_x_gpu(8)
    def test_ucc_backend_4gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_ucc_backend(devices, None, multi_device=True)

    def _test_global_local_unused_params_grad(
        self, gradient_as_bucket_view=False, static_graph=False
    ):
        """
        By simulating multi-task training, this test makes sure that:
        1) DDP does not touch the grad of globally unused parameters.
        2) DDP does update the grad of locally unused parameters.
        """

        class GlobalLocalUnusedParamModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()
                self.task_unused = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p, self.task_unused.p)

            def forward(self, x, rank):
                return self.t0(x) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            t0_p, t1_p, task_unused_p = model.module.task_parameters()
            self.assertIsNone(t0_p.grad)
            self.assertIsNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

            # Run backward
            output.mean().backward()

            # Now the locally unused parameters should have their grads updated on all ranks.
            # However, the globally unused parameter should still have a None grad.
            self.assertIsNotNone(t0_p.grad)
            self.assertIsNotNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(gpu_model)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad(self):
        self._test_global_local_unused_params_grad()

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_grad_is_view(self):
        self._test_global_local_unused_params_grad(gradient_as_bucket_view=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_static_graph(self):
        self._test_global_local_unused_params_grad(static_graph=True)

    # TODO: times out
    @skip_but_pass_in_sandcastle("times out")
    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_find_unused_parameters_when_unused_parameters_empty(self):
        """
        An empty unused_parameters array does not imply find_unused_parameters =
        false. This test makes sure that DDP allreduces unused parameters
        correctly when the forward pass in some processes uses all parameters.
        The test creates a module that uses all parameters on rank 0 and has
        unused parameters on the other ranks.
        """

        class FindUnusedParamModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p)

            def forward(self, x, rank):
                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            [self.assertIsNone(t_p.grad) for t_p in model.module.task_parameters()]

            # Run backward
            output.mean().backward()

            # Now the locally unused parameters should have their grads updated on all ranks.
            [self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()]

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            FindUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            FindUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(gpu_model)

    @requires_ucc()
    def test_ignored_output(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called.
        """
        process_group = self._get_process_group()

        class IgnoredOutput(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutput().float(),
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_ucc()
    def test_ignored_output_with_unused_parameters(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called, if not all model
        parameters participated in computing the model output.
        """
        process_group = self._get_process_group()

        class IgnoredOutputWithUnusedParameters(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutputWithUnusedParameters().float(),
            process_group=process_group,
            find_unused_parameters=True,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
        mult = 2
        batch_size = mult * self.world_size
        criterion = nn.CrossEntropyLoss()
        input = torch.randint(0, 10, [batch_size, 2])
        target = torch.randint(0, 10, [batch_size])

        # Run with entire batch against single process version
        criterion(vanilla_model(input), target).backward()

        # Run with partial batch against multi process version
        partial_input = input.split(mult)[self.rank]
        partial_target = target.split(mult)[self.rank]
        criterion(ddp_model(partial_input), partial_target).backward()

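        # DDP averages the per-rank gradients, so the partial-batch run should
        # produce the same gradients as the single-process full-batch run.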
        # Check that the gradients are sparse and identical
        vanilla_parameter = next(vanilla_model.parameters())
        ddp_parameter = next(ddp_model.parameters())
        self.assertEqual(
            vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce()
        )

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_save_load_checkpoint(self):
        dist.init_process_group(
            "ucc",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )

        class TestModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        def train_loop(model, optimizer, iterations):
            for _ in range(iterations):
                optimizer.zero_grad()
                output = model(input)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        device_id = gpus_for_rank(self.world_size)[self.rank][0]

        model_withload = TestModel().float().to(device_id)
        model_withoutload = TestModel().float().to(device_id)

        ddp_withload = DistributedDataParallel(
            model_withload,
            device_ids=[device_id],
        )
        ddp_withoutload = DistributedDataParallel(
            model_withoutload,
            device_ids=[device_id],
        )

        # Ensure that all three models start with the same set of parameters; by default they are randomized on construction.
        for p in ddp_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in model_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in ddp_withoutload.parameters():
            with torch.no_grad():
                p.zero_()

        batch_size = 4
        criterion = nn.CrossEntropyLoss()

        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
        optimizer_non_ddp_withload = torch.optim.SGD(
            model_withload.parameters(), lr=0.001
        )
        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)

        input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
            device_id
        )

        # run the model for 6 iterations, with a checkpoint in the middle
        train_loop(ddp_withload, optimizer_withload, 3)

        # zero out parameters of both DDP and non-DDP models and reload them from the DDP state dict
        checkpoint_path = tempfile.gettempdir() + "/model.checkpoint"
        if self.rank == 0:
            torch.save(ddp_withload.state_dict(), checkpoint_path)

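        # Wait for rank 0 to finish writing the checkpoint, then remap tensors that
        # were saved from rank 0's GPU onto this rank's GPU when loading.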
        dist.barrier()
        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

        for model in [ddp_withload, model_withload]:
            for p in model.parameters():
                with torch.no_grad():
                    p.zero_()
        ddp_withload.load_state_dict(ddp_state_dict)
        # the non-DDP model needs to first remove the prefix of "module." from the DDP state dict
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        model_withload.load_state_dict(ddp_state_dict)

        train_loop(ddp_withload, optimizer_withload, 3)
        train_loop(model_withload, optimizer_non_ddp_withload, 3)

        # re-run the model with the same inputs for 6 iterations with no checkpoint
        train_loop(ddp_withoutload, optimizer_withoutload, 6)

        for p_withload, p_withoutload, p_non_ddp_withload in zip(
            ddp_withload.parameters(),
            ddp_withoutload.parameters(),
            model_withload.parameters(),
        ):
            self.assertEqual(p_withload, p_withoutload)
            self.assertEqual(p_non_ddp_withload, p_withoutload)

    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients(self):
        self._test_sparse_gradients()

    # TODO: backward pass: input tensor has to be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_sparse_gradients_grad_is_view(self):
        self._test_sparse_gradients(gradient_as_bucket_view=True)

    @requires_ucc()
    def test_ddp_comm_hook_future_passing_cpu(self):
        """
        This unit test verifies whether the Future object is passed properly.
        The callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().cpu(), process_group=process_group
        )

        # Register DDP Communication Hook
        cpu_model.register_comm_hook(None, self._simple_hook)

        # Check whether the grads are equal to what the hook's then callback returns.
        # Without the comm_hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_ucc(self):
        """
        This unit test verifies whether the Future object is passed properly using ucc backend.
        The hook callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Get GPU model with simple_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)

        # check whether the grads are equal to what simple_hook's then callback returns.
        # without the comm_hook, result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))

    @requires_ucc()
    def test_ddp_invalid_comm_hook_init(self):
        """
        This unit test makes sure that register_comm_hook properly checks the format
        of the user-defined hook. The Python hook must be callable. This test also
        checks whether the bucket annotation is checked properly, if defined.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
            model.register_comm_hook(state=None, hook=1)

        with self.assertRaisesRegex(
            ValueError, "bucket annotation should be dist.GradBucket."
        ):

            def comm_hook(
                state: object, bucket: int
            ) -> torch.futures.Future[torch.Tensor]:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

    @requires_ucc()
    def test_ddp_invalid_comm_hook_return_type(self):
        """
        This test checks whether the return annotation is checked properly, if defined.
        It also checks whether an internal error is thrown if the return type is
        incorrect and the user hasn't specified any return type annotation.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        expected_err = (
            "Communication hook: return annotation should be torch.futures.Future"
        )
        with self.assertRaisesRegex(
            ValueError,
            expected_err,
        ):

            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

        verify_ddp_error_logged(model, expected_err)

        with self.assertRaisesRegex(
            RuntimeError,
            "callback must return a torch.futures.Future object, but got",
        ):

            def comm_hook(state: object, bucket: dist.GradBucket):
                return 1

            model.register_comm_hook(state=None, hook=comm_hook)

            # Run forward
            output = model(8, self.rank)

            # Run backward
            output.mean().backward()

    @requires_ucc()
    def test_ddp_comm_hook_register_just_once(self):
        """
        A DDP communication hook can only be registered once. This test validates
        that the error is thrown properly when register_comm_hook is called more
        than once.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        def dummy_hook(state, bucket):
            fut = torch.futures.Future()
            fut.set_result([bucket.buffer()])
            return fut

        model.register_comm_hook(None, dummy_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "register_comm_hook or register_builtin_comm_hook can only be called once.",
        ):
            model.register_comm_hook(None, dummy_hook)

    # TODO: backward pass: input tensor must be dense
    @skip_but_pass_in_sandcastle("backward pass: input tensor has to be dense")
    @requires_ucc()
    def test_ddp_comm_hook_sparse_gradients(self):
        """
        Runs the "test_sparse_gradients" unit test with a DDP communication hook. For
        this test we define a simple hook that performs allreduce and works with the
        ucc backend.
        """
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
        )

        def allreduce_hook_ucc(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            def div_by_world_size(fut):
                # Divide the summed result by the world size to get the average.
                return fut.wait()[0] / self.world_size

            # Prepare allreduced grad bucket tensors by running an async work.
            fut = process_group.allreduce([bucket.buffer()]).get_future()
            return fut.then(div_by_world_size)

        ddp_model.register_comm_hook(None, allreduce_hook_ucc)

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)


class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    @property
    def device(self):
        return "cpu"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_ucc(self):
        self._test_sequence_num_set_default_pg(backend="ucc")

    @requires_ucc()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_ucc_new_group(self):
        self._test_sequence_num_set_new_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_default(self):
        self._test_sequence_num_incremented_default_group("ucc")

    @skip_if_lt_x_gpu(4)
    @requires_ucc()
    def test_sequence_num_incremented_ucc_subgroup(self):
        if self.world_size < 4:
            return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    def test_ucc_barrier_device_ids(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="ucc", rank=self.rank, world_size=self.world_size, store=store
        )

        with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
            c10d.barrier(device_ids=[self.rank])

    @skip_but_pass_in_sandcastle("Fails on M60")
    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_warn_not_in_group(self):
        self._test_warn_not_in_group(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_ucc_rank_membership(self):
        self._test_rank_membership(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_mismatch(self):
        self._test_tensor_dtype_mismatch(backend="ucc")

    @skip_if_lt_x_gpu(2)
    @requires_ucc()
    def test_tensor_dtype_complex(self):
        self._test_tensor_dtype_complex(backend="ucc")


class UccProcessGroupWithDispatchedCollectivesTests(
    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
):
    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_collectives(self):
        # includes reduce, broadcast, all_reduce, all_gather, reduce_scatter, barrier, all_to_all, scatter
        self._test_collectives(backend="ucc")

    @skip_but_pass_in_sandcastle("Fails on M60")
    @requires_ucc()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "ucc",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda"
        tensor = torch.ones(10, 10, device=torch.device(device))
        output_tensor = torch.zeros(10, 10, device=torch.device(device))
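        # Note: the identical input/output shapes assume a single-rank world;
        # all_gather_into_tensor requires output numel == world_size * input numel.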
        dist.all_gather_into_tensor(output_tensor, tensor)
        self.assertEqual(output_tensor, tensor)


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()