# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, List, NamedTuple, Optional, Type

import torch
import torch.distributed as dist


__all__ = ["JoinHook", "Joinable", "Join"]


class JoinHook:
15    r"""
16    This defines a join hook, which provides two entry points in the join context manager.
17
18    Entry points : a main hook, which is called repeatedly while there exists a non-joined
19    process, and a post-hook, which is called once all processes have joined.
20
21    To implement a join hook for the generic join context manager, define a
22    class that inherits from :class:`JoinHook` and override ``main_hook()`` and
23    ``post_hook()`` as appropriate.
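
    A minimal sketch of a custom join hook that shadows one all-reduce per
    training iteration; the ``counter`` object and its ``device`` and
    ``process_group`` attributes are hypothetical and would be supplied by the
    corresponding :class:`Joinable`::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.join import JoinHook
        >>>
        >>> class CounterJoinHook(JoinHook):
        >>>     def __init__(self, counter):
        >>>         self.counter = counter
        >>>
        >>>     def main_hook(self):
        >>>         # Match the non-joined processes' all-reduce with a zero contribution
        >>>         t = torch.zeros(1, device=self.counter.device)
        >>>         dist.all_reduce(t, group=self.counter.process_group)
        >>>
        >>>     def post_hook(self, is_last_joiner: bool):
        >>>         # Nothing to synchronize once all processes have joined
        >>>         pass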
24    """
25
26    def main_hook(self) -> None:
27        r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
28
29        Training iteration i.e., in one forward pass, backward pass, and optimizer step.
30        """

    def post_hook(self, is_last_joiner: bool) -> None:
        r"""
        Call this hook after all processes have joined.

        It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.

        Arguments:
            is_last_joiner (bool): ``True`` if the rank is one of the last to
                join; ``False`` otherwise.
        """


class Joinable(ABC):
45    r"""
46    This defines an abstract base class for joinable classes.
47
48    A joinable class
49    (inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
50    which returns a :class:`JoinHook` instance, in addition to
51    :meth:`join_device` and :meth:`join_process_group` that return device and
52    process group information, respectively.
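
    A minimal sketch of a custom joinable; ``device`` and ``process_group``
    are assumed to be supplied by the caller, and ``CounterJoinHook`` is a
    hypothetical :class:`JoinHook` that shadows the all-reduce in
    ``increment()`` (see :class:`JoinHook` for a matching sketch)::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.distributed as dist
        >>> from torch.distributed.algorithms.join import Join, Joinable, JoinHook
        >>>
        >>> class Counter(Joinable):
        >>>     def __init__(self, device, process_group):
        >>>         super().__init__()  # required to initialize the join config
        >>>         self.device = device
        >>>         self.process_group = process_group
        >>>         self.count = 0
        >>>
        >>>     def increment(self):
        >>>         # Notify the join context manager before the per-iteration collective
        >>>         Join.notify_join_context(self)
        >>>         t = torch.ones(1, device=self.device)
        >>>         dist.all_reduce(t, group=self.process_group)
        >>>         self.count += int(t.item())
        >>>
        >>>     def join_hook(self, **kwargs) -> JoinHook:
        >>>         return CounterJoinHook(self)
        >>>
        >>>     @property
        >>>     def join_device(self) -> torch.device:
        >>>         return self.device
        >>>
        >>>     @property
        >>>     def join_process_group(self):
        >>>         return self.process_group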
53    """
54
55    @abstractmethod
56    def __init__(self) -> None:
57        super().__init__()
58        self._join_config = _JoinConfig.construct_disabled_join_config()
59
60    @abstractmethod
61    def join_hook(self, **kwargs) -> JoinHook:
62        r"""
63        Return a :class:`JoinHook` instance for the given :class:`Joinable`.
64
65        Arguments:
66            kwargs (dict): a :class:`dict` containing any keyword arguments
67                to modify the behavior of the join hook at run time; all
68                :class:`Joinable` instances sharing the same join context
69                manager are forwarded the same value for ``kwargs``.
70        """
71        ...
72
73    @property
74    @abstractmethod
75    def join_device(self) -> torch.device:
76        r"""Return the device from which to perform collective communications needed by the join context manager."""
77        ...
78
79    @property
80    @abstractmethod
81    def join_process_group(self) -> Any:
82        r"""Returns the process group for the collective communications needed by the join context manager itself."""
        ...


class _JoinConfig(NamedTuple):
    r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""

    enable: bool
    throw_on_early_termination: bool
    is_first_joinable: bool

    @staticmethod
    def construct_disabled_join_config():
95        r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
96
97        e.g. if the caller is not in a join context manager.
98        """
        return _JoinConfig(
            enable=False, throw_on_early_termination=False, is_first_joinable=False
        )


class Join:
105    r"""
106    This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
107
108    These hooks should shadow the
109    collective communications of non-joined processes to prevent hanging and
110    erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
111    for details about the hook definition.
112
113    .. warning::
114        The context manager requires each participating :class:`Joinable` to
115        call the method :meth:`notify_join_context()` before its own per-
116        iteration collective communications to ensure correctness.
117
    .. warning::
        The context manager requires that all ``process_group`` attributes in
        the :class:`JoinHook` objects are the same. If there are multiple
        :class:`JoinHook` objects, then the ``device`` of the first is used.
        The process group and device information is used for checking for
        non-joined processes and for notifying processes to throw an exception
        if ``throw_on_early_termination`` is enabled, both of which use an
        all-reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.

        enable (bool): a flag enabling uneven input detection; setting to
            ``False`` disables the context manager's functionality and should
            only be set when the user knows the inputs will not be uneven
            (default: ``True``).

        throw_on_early_termination (bool): a flag controlling whether to throw an
            exception upon detecting uneven inputs (default: ``False``).

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>>     # All ranks reach here without hanging/erroring
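
    A sketch of how ``throw_on_early_termination=True`` might be used with the
    same setup as above (``model``, ``optim``, and ``inputs`` as defined in
    ``worker()``); the raised ``RuntimeError`` propagates on every rank, so it
    is caught outside the context manager::

        >>> # xdoctest: +SKIP
        >>> try:
        >>>     with Join([model, optim], throw_on_early_termination=True):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>> except RuntimeError:
        >>>     # At least one rank exhausted its inputs first
        >>>     pass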
164    """
165
166    def __init__(
167        self,
168        joinables: List[Joinable],
169        enable: bool = True,
170        throw_on_early_termination: bool = False,
171        **kwargs,
172    ):
173        if len(joinables) == 0:
174            raise ValueError("The join context manager requires at least one joinable")
175        self._joinables = joinables
176        self._join_hooks = [
177            joinable.join_hook(**kwargs) for joinable in self._joinables
178        ]
179        self._enable = enable
180        self._throw_on_early_termination = throw_on_early_termination
181        self._set_joinable_configs()
182        self._extract_dist_info()
183
184    def _set_joinable_configs(self) -> None:
185        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
186        assert len(self._joinables) > 0
187        is_first_joinable = True
188        for joinable in self._joinables:
189            joinable._join_config = _JoinConfig(
190                enable=self._enable,
191                throw_on_early_termination=self._throw_on_early_termination,
192                is_first_joinable=is_first_joinable,
193            )
194            is_first_joinable = False
195
196    def _extract_dist_info(self) -> None:
197        r"""
198        Extract the process group and device information from the joinables.
199
200        If there are multiple joinables, then the context manager uses the
201        first specified device.
202
203        Preconditions:
204            ``self._joinables`` is not ``None`` and is non-empty.
205
206        Raises:
207            ValueError
208                If there are multiple conflicting ``process_group`` attributes
209                among the ``Joinable`` objects.
210        """
211        process_group = None
212        device = None
213        for joinable in self._joinables:
214            if process_group is None:
215                process_group = joinable.join_process_group
216            elif process_group != joinable.join_process_group:
217                raise ValueError(
218                    "Using join context manager with multiple process groups"
219                )
220            if device is None:
221                device = joinable.join_device
222        self._process_group = process_group
223        self._rank = dist.get_rank(self._process_group)
224        self._device = device
225
226    def __enter__(self):
227        ...
228
229    def __exit__(
230        self,
231        type: Optional[Type[BaseException]],
232        value: Optional[BaseException],
233        traceback: Optional[TracebackType],
234    ):
235        r"""
236        Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.
237
238        Raises:
239            RuntimeError
240                If ``throw_on_early_termination=True``.
241        """
        if not self._enable or type:
            return  # propagate the exception directly if one was raised

        all_procs_joined = False
        is_last_joiner = True
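        # ``is_last_joiner`` stays ``True`` only if this process never runs the
        # main hooks below, i.e. it is among the last processes to join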

        i = 0
        WARN_THRESHOLD = 1000
        warnings.simplefilter("once")

        while not all_procs_joined:
            if i > WARN_THRESHOLD:
                warnings.warn(
                    "Detected uneven input skew of greater than "
                    f"{WARN_THRESHOLD}. This means that rank "
                    f"{self._rank} has at least {WARN_THRESHOLD} "
                    f"fewer inputs than other currently-active ranks. "
                    "This level of skew could lead to performance "
                    "degradation during training."
                )
            # Shadow the all-reduce in non-joined processes
            num_nonjoined_procs = self._get_num_nonjoined_procs()
            if num_nonjoined_procs == 0:
                all_procs_joined = True
            else:
                if self._throw_on_early_termination:
                    self._notify_procs_to_terminate()

                # Run main hooks
                for join_hook in self._join_hooks:
                    join_hook.main_hook()

                is_last_joiner = False
                i += 1

        # Run post-hooks
        for join_hook in self._join_hooks:
            join_hook.post_hook(is_last_joiner)

    def _get_num_nonjoined_procs(self):
        r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
        num_nonjoined_procs = torch.zeros(1, device=self._device)
        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
        return num_nonjoined_procs.item()

    def _notify_procs_to_terminate(self):
        r"""Schedule an all-reduce to notify non-joined processes to terminate.

        Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
        """
        ones = torch.ones(1, device=self._device)
        dist.all_reduce(ones, group=self._process_group)
        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")

    @staticmethod
    def notify_join_context(joinable: Joinable):
298        r"""
299        Notifies the join context manager that the calling process has not yet joined.
300
301        Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
302        (i.e. if one process has already joined) and throws an exception if so.
303
304        This method should be called from a :class:`Joinable` object before
305        its per-iteration collective communications. For example, this should
306        be called at the beginning of the forward pass in
307        :class:`DistributedDataParallel`.
308
309        Only the first :class:`Joinable` object passed into the context
310        manager performs the collective communications in this method, and
311        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the context
            manager that the process has not yet joined if ``joinable`` is the
            first one passed into the context manager; ``None`` otherwise.
        """
        assert hasattr(joinable, "_join_config"), (
            f"Check that the {type(joinable)} constructor calls the "
            "``Joinable`` constructor"
        )

        join_config = joinable._join_config
        # First joinable is responsible for the collective communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None

        device = joinable.join_device
        process_group = joinable.join_process_group

        # Schedule an all-reduce to indicate that the caller has not yet joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)

        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
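            # This matches the all-reduce of ones in _notify_procs_to_terminate()
            # on a process that has exhausted its inputs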
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work