# Owner(s): ["oncall: distributed"]

import contextlib
import os
import sys
from typing import Any, Optional

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    require_n_gpus_for_nccl_backend,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

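# Use NCCL if CUDA is available and Gloo otherwise; clamp the world size to
# between 2 and 4 processes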
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))

# Constants used for testing post-hooks
BEFORE_CONSTANT = 41
AFTER_CONSTANT = 42


class AllReducerJoinHook(JoinHook):
    r"""
    Join hook for :class:`AllReducer`.

    Arguments:
        allreducer (AllReducer): the :class:`AllReducer` object using this
            hook.
        num_allreduces (int): the number of all-reduces to shadow per
            iteration.
        run_post_hook (bool): a flag enabling the post-hook logic.
    """

    def __init__(self, allreducer, num_allreduces, run_post_hook):
        self.allreducer = allreducer
        self.num_allreduces = num_allreduces
        self.run_post_hook = run_post_hook

    def main_hook(self):
        r"""
        Shadows each all-reduce; the number of all-reduces is passed into the
        constructor as ``num_allreduces``.
        """
        device = self.allreducer.device
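        # Shadow with zeros so that these all-reduces do not perturb the
        # running totals accumulated by ranks that have not yet joined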
        for _ in range(self.num_allreduces):
            t = torch.zeros(1, device=device)
            dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Broadcasts a tensor containing a magic constant ``AFTER_CONSTANT`` from
        the last joiner to all other processes.
        """
        if not self.run_post_hook:
            return
        rank = dist.get_rank(self.allreducer.process_group)
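        # Agree on a single source rank for the broadcast below: the highest
        # rank among the last joiner(s)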
        common_rank = self.allreducer.find_common_rank(rank, is_last_joiner)
        device = self.allreducer.device
        if rank == common_rank:
            self.allreducer.post_hook_tensor = torch.tensor(
                [AFTER_CONSTANT], device=device
            )
        dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank)


class AllReducer(Joinable):
    r"""
    Example :class:`Joinable` that performs some number of all-reduces as its
    per-iteration collective communication.
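
    A minimal usage sketch (assuming the default process group is already
    initialized and ``inputs`` holds this rank's possibly uneven input list)::

        allreducer = AllReducer(torch.device("cpu"), dist.group.WORLD)
        with Join([allreducer]):
            for _ in inputs:
                total = allreducer()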
    """

    def __init__(self, device, process_group):
        super().__init__()
        self.device = device
        self.process_group = process_group
        self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device)

    def __call__(self, num_allreduces=1):
        r"""
        All-reduces a 1-D tensor of ones ``num_allreduces``-many times and
        returns the total result.
        """
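        # Notify the join context manager that this rank has not yet joined
        # before running this iteration's collective communications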
        Join.notify_join_context(self)
        device = self.device
        total = 0
        for _ in range(num_allreduces):
            t = torch.ones(1, device=device)
            dist.all_reduce(t)
            total += t.item()
        return total

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Returns a join hook that shadows some number of all-reduces; by default,
        this number is 1.
        """
        num_allreduces = kwargs.get("num_allreduces", 1)
        run_post_hook = kwargs.get("run_post_hooks", False)
        return AllReducerJoinHook(self, num_allreduces, run_post_hook)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self) -> Any:
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the maximum rank among those under consideration across the
        process group.
        """
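        # Ranks not under consideration contribute -1 so that the max
        # all-reduce selects the highest rank under consideration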
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        assert common_rank >= 0
        return common_rank


class TestJoin(MultiProcessTestCase):
    r"""Test cases for the generic join context."""

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
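        # NCCL requires one GPU per process, so map rank i to device index i;
        # Gloo runs on the CPU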
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def tearDown(self):
        try:
            dist.destroy_process_group()
        except AssertionError:
            pass
        try:
            os.remove(self.file_name)
        except OSError:
            pass

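    # Rendezvous via a file-backed store shared by all spawned processes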
    def dist_init(self, rank, world_size, backend=BACKEND):
        store = dist.FileStore(self.file_name, world_size)
        return dist.init_process_group(
            backend=backend, store=store, rank=rank, world_size=world_size
        )

    def construct_uneven_inputs(self, base, offset, device=None):
        r"""
        Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs.
        """
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)]

    def construct_even_inputs(self, base, device=None):
        r"""Returns even inputs: each rank gets ``base`` inputs."""
        if device is None:
            device = self.device
        return [torch.zeros(1, device=device) for _ in range(base)]

    @property
    def base_num_inputs(self):
        r"""Base number of inputs to be used by all ranks."""
        return 3

    @property
    def offset(self):
        r"""Rank i gets i * ``offset`` additional inputs."""
        return 1

    def _test_join_base(
        self,
        uneven_inputs: bool,
        num_joinables: int,
        enable: bool,
        throw_on_early_termination: bool,
        num_allreduces: int,
        run_post_hooks: bool,
        expected_total: Optional[int] = None,
    ):
        r"""
        Skeleton for all :class:`Join` tests.

        Arguments:
            uneven_inputs (bool): ``True`` to use uneven inputs; ``False``
                otherwise.
            num_joinables (int): number of :class:`AllReducer` s to construct.
            enable (bool): ``True`` to enable the join context manager;
                ``False`` otherwise.
            throw_on_early_termination (bool): ``True`` to raise an exception
                upon detecting uneven inputs; ``False`` otherwise.
            num_allreduces (int): number of all-reduces to perform per input.
            run_post_hooks (bool): ``True`` to run post-hooks; ``False``
                otherwise.
            expected_total (Optional[int]): ``None`` to not check the expected
                all-reduce total; otherwise, the expected total; default is
                ``None``.
        """
        self.dist_init(self.rank, self.world_size)

        allreducers = [
            AllReducer(self.device, self.process_group) for _ in range(num_joinables)
        ]
        for allreducer in allreducers:
            self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT)

        inputs = (
            self.construct_uneven_inputs(self.base_num_inputs, self.offset)
            if uneven_inputs
            else self.construct_even_inputs(self.base_num_inputs)
        )
        allreduce_total = 0

        # Expect a `RuntimeError` if `throw_on_early_termination=True`
        # Rank 0 exhausts its inputs first
        expected_msg = (
            "Rank 0 exhausted all inputs."
            if self.rank == 0
            else "Detected at least one rank that exhausted inputs. "
            "Throwing across all ranks."
        )
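        # Enter the `Join` context, wrapping it in `assertRaisesRegex` only
        # when an early-termination exception is expected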
        with self.assertRaisesRegex(
            RuntimeError, expected_msg
        ) if throw_on_early_termination else contextlib.nullcontext():
            with Join(
                allreducers,
                enable=enable,
                throw_on_early_termination=throw_on_early_termination,
                num_allreduces=num_allreduces,
                run_post_hooks=run_post_hooks,
            ):
                for _ in inputs:
                    for allreducer in allreducers:
                        allreduce_total += allreducer(num_allreduces)

        if throw_on_early_termination:
            return

        # Check `expected_total` if not `None`
        if expected_total is not None:
            self.assertEqual(allreduce_total, expected_total)

        # All `AllReducer` instances should receive the updated
        # `post_hook_tensor` from the last-joined process
        if run_post_hooks:
            for allreducer in allreducers:
                self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT)

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable_main_hooks(self):
        r"""Tests the main hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = False
        # Non-joined processes all-reduce a 1, so this rank's all-reduce total
        # should be precisely equal to the total number of inputs processed
        # before it joined
        expected_total = self.world_size * self.base_num_inputs
        # Rank i runs for i additional iterations
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
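        # For example, with `world_size=2`, `base_num_inputs=3`, and
        # `offset=1`, rank 1 expects a total of 2 * 3 + (2 - 1) * 1 = 7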

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable_post_hooks(self):
        r"""Tests the post-hooks of a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 0  # set to 0 to skip the main hooks
        run_post_hooks = True

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable(self):
        r"""
        Tests the main hooks and post-hooks of a single :class:`Joinable`
        together.

        This combines ``test_single_joinable_main_hooks()`` and
        ``test_single_joinable_post_hooks()`` into a single test to ensure that
        main hooks and post-hooks operate correctly together.
        """
        num_joinables = 1
        num_allreduces = 1
        run_post_hooks = True

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_multiple_joinables(self):
        r"""
        Tests the main hooks and post-hooks of multiple :class:`Joinable` s
        together.

        This generalizes ``test_single_joinable()`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        run_post_hooks = True

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `num_joinables`
        expected_total *= num_joinables

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable_disable(self):
        r"""Tests ``enable=False`` for a single :class:`Joinable`."""
        num_joinables = 1
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs

        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_multiple_joinable_disable(self):
        r"""
        Tests ``enable=False`` for multiple :class:`Joinable` s.

        This generalizes ``test_single_joinable_disable`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        uneven_inputs = False
        enable = False
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs * num_joinables

        self._test_join_base(
            uneven_inputs=uneven_inputs,
            num_joinables=num_joinables,
            enable=enable,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_single_joinable_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for a single
        :class:`Joinable`.
        """
        num_joinables = 1
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_multiple_joinables_throw(self):
        r"""
        Tests ``throw_on_early_termination=True`` for multiple
        :class:`Joinable` s together.

        This generalizes ``test_single_joinable_throw`` to multiple
        :class:`Joinable` s.
        """
        num_joinables = 3
        num_allreduces = 1
        throw_on_early_termination = True
        run_post_hooks = False

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=throw_on_early_termination,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=None,
        )

    @require_n_gpus_for_nccl_backend(WORLD_SIZE, BACKEND)
    def test_join_kwargs(self):
        r"""
        Tests passing keyword arguments to the context manager.
        """
        num_joinables = 1
        num_allreduces = 2
        run_post_hooks = False

        expected_total = self.world_size * self.base_num_inputs
        for num_joined in range(1, self.rank + 1):
            expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `num_allreduces`
        expected_total *= num_allreduces

        self._test_join_base(
            uneven_inputs=True,
            num_joinables=num_joinables,
            enable=True,
            throw_on_early_termination=False,
            num_allreduces=num_allreduces,
            run_post_hooks=run_post_hooks,
            expected_total=expected_total,
        )


if __name__ == "__main__":
    run_tests()