xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/rendezvous/dynamic_rendezvous.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Facebook, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import inspect
9import logging
10import os
11import pickle
12import socket
13import threading
14import time
15import weakref
16from abc import ABC, abstractmethod
17from dataclasses import dataclass
18from datetime import datetime, timedelta
19from enum import Enum
20from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple
21
22import torch.distributed as dist
23from torch.distributed import Store
24from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
25
26from .api import (
27    RendezvousClosedError,
28    RendezvousError,
29    RendezvousGracefulExitError,
30    RendezvousHandler,
31    RendezvousInfo,
32    RendezvousParameters,
33    RendezvousStateError,
34    RendezvousStoreInfo,
35    RendezvousTimeoutError,
36)
37from .utils import _delay, _PeriodicTimer
38
39
# Names exported by ``from <this module> import *``.
__all__ = [
    "RendezvousBackend",
    "RendezvousTimeout",
    "RendezvousSettings",
    "DynamicRendezvousHandler",
    "create_handler",
]

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
49
50
def get_method_name(depth=2):
    """Return the name of the function ``depth`` frames up the call stack.

    Frame 0 is this function itself, frame 1 is its caller, and frame 2 (the
    default) is the caller's caller.

    Args:
        depth:
            Index into the current call stack of the frame whose function
            name is wanted.

    Returns:
        The function name of the selected frame, or ``"no_method_name"`` if
        the stack is not deeper than ``depth``.
    """
    # Capture the stack once. ``context=0`` skips loading source lines for
    # every frame, which are not needed to read the function name.
    frames = inspect.stack(context=0)
    if len(frames) > depth:
        return frames[depth].function
    return "no_method_name"
55
56
# Alias used in the backend signatures below; it is typed as ``Any`` so no
# particular token representation is imposed on a backend.
Token = Any
"""Represent an opaque fencing token used by the rendezvous backend."""
59
60
class RendezvousBackend(ABC):
    """Abstract interface of a store that keeps the rendezvous state."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the backend's name."""

    @abstractmethod
    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """Fetch the rendezvous state currently held by the backend.

        Returns:
            ``None`` when the backend holds no state; otherwise a pair made of
            the encoded rendezvous state and the fencing token that goes with
            it.

        Raises:
            RendezvousConnectionError:
                The backend could not be reached.
            RendezvousStateError:
                The stored rendezvous state is corrupt.
        """

    @abstractmethod
    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """Conditionally store a new rendezvous state.

        Whether the write happens depends on ``token``:

          - ``token`` equals the fencing token held by the backend: the state
            is replaced, and the new state plus its fencing token are handed
            back to the caller.
          - ``token`` differs from the fencing token held by the backend:
            nothing is written, and the state already in the backend plus its
            fencing token are handed back instead.
          - ``token`` is ``None``: the state is written only when the backend
            holds no state yet. The caller receives whichever state ends up
            in the backend (new or pre-existing) plus its fencing token.

        Args:
            state:
                The encoded rendezvous state.
            token:
                An optional fencing token obtained from an earlier call to
                :py:meth:`get_state` or ``set_state()``.

        Returns:
            A triple of the serialized rendezvous state, its fencing token,
            and a boolean telling whether this call's write succeeded.

        Raises:
            RendezvousConnectionError:
                The backend could not be reached.
            RendezvousStateError:
                The stored rendezvous state is corrupt.
        """
120
121
class RendezvousTimeout:
    """Hold the timeout configuration of a rendezvous.

    Any timeout passed as ``None`` falls back to its entry in
    ``_DEFAULT_TIMEOUTS``. Every resulting timeout must be strictly positive.

    Args:
        join:
            The time within which the rendezvous is expected to complete.
        last_call:
            An additional wait amount before completing the rendezvous once the
            rendezvous has the minimum number of required participants.
        close:
            The time within which the rendezvous is expected to close after a
            call to :py:meth:`RendezvousHandler.set_closed` or
            :py:meth:`RendezvousHandler.shutdown`.
        heartbeat:
            The time within which a keep-alive heartbeat is expected to
            complete.

    Raises:
        ValueError:
            A timeout (given or defaulted) is zero or negative.
    """

    _ZERO = timedelta(0)

    # Fallback values, keyed by the constructor's parameter names.
    _DEFAULT_TIMEOUTS = {
        "join": timedelta(seconds=600),
        "last_call": timedelta(seconds=30),
        "close": timedelta(seconds=30),
        "heartbeat": timedelta(seconds=5),
    }

    _join: timedelta
    _last_call: timedelta
    _close: timedelta
    _heartbeat: timedelta

    def __init__(
        self,
        join: Optional[timedelta] = None,
        last_call: Optional[timedelta] = None,
        close: Optional[timedelta] = None,
        heartbeat: Optional[timedelta] = None,
    ) -> None:
        self._set_timeouts(
            join=join, last_call=last_call, close=close, heartbeat=heartbeat
        )

    @property
    def join(self) -> timedelta:
        """Get the join timeout."""
        return self._join

    @property
    def last_call(self) -> timedelta:
        """Get the last call timeout."""
        return self._last_call

    @property
    def close(self) -> timedelta:
        """Get the close timeout."""
        return self._close

    @property
    def heartbeat(self) -> timedelta:
        """Get the keep-alive heartbeat timeout."""
        return self._heartbeat

    def _set_timeouts(self, **timeouts: Optional[timedelta]) -> None:
        """Validate each named timeout and store it as ``_<name>``.

        Raises:
            ValueError:
                A timeout is not strictly positive.
        """
        for name, timeout in timeouts.items():
            if timeout is None:
                timeout = self._DEFAULT_TIMEOUTS[name]
            if timeout <= self._ZERO:
                raise ValueError(f"The {name} timeout ({timeout}) must be positive.")
            # The properties above read these underscore-prefixed attributes.
            setattr(self, "_" + name, timeout)
192
193
@dataclass(repr=False, eq=False, frozen=True)
class RendezvousSettings:
    """Immutable bundle of the settings that govern a rendezvous.

    Attributes:
        run_id:
            The run id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        timeout:
            The timeout configuration of the rendezvous.
        keep_alive_interval:
            How long a node waits before it sends a heartbeat to stay alive
            in the rendezvous.
        keep_alive_max_attempt:
            How many heartbeat attempts may fail before a node is considered
            dead.
    """

    # Identity and size limits.
    run_id: str
    min_nodes: int
    max_nodes: int

    # Timing configuration.
    timeout: RendezvousTimeout
    keep_alive_interval: timedelta
    keep_alive_max_attempt: int
221
222
223@dataclass(eq=True, order=True, frozen=True)
224class _NodeDesc:
225    """Describe a node in the rendezvous.
226
227    Attributes:
228        addr:
229            The FQDN of the node or user specified local node address.
230        pid:
231            The id of the process in which the rendezvous handler runs.
232        local_id:
233            A process-wide unique id.
234    """
235
236    addr: str
237    pid: int
238    local_id: int
239
240    def __repr__(self) -> str:
241        return f"{self.addr}_{self.pid}_{self.local_id}"
242
243
class _NodeDescGenerator:
    """Factory of node descriptors.

    Each descriptor combines an address (the given one or the FQDN), the
    current process id, and an integer that grows by one with every call to
    :py:meth:`generate`.
    """

    _lock: threading.Lock
    _local_id: int

    def __init__(self) -> None:
        # Next id to hand out; bumped by every generate() call.
        self._local_id = 0

        # Guards the read-and-increment of ``_local_id``.
        self._lock = threading.Lock()

    def generate(self, local_addr: Optional[str] = None) -> _NodeDesc:
        """Create a descriptor for this process.

        Args:
            local_addr:
                Address to put in the descriptor; when falsy,
                ``socket.getfqdn()`` is used instead.
        """
        # Several threads may call this at once, so taking the current id and
        # advancing the counter has to happen under the lock.
        with self._lock:
            assigned_id = self._local_id
            self._local_id = assigned_id + 1

        addr = local_addr or socket.getfqdn()
        return _NodeDesc(addr, os.getpid(), assigned_id)
269
270
class _RendezvousState:
    """Plain container for the state of a rendezvous.

    Attributes:
        round:
            The current round of the rendezvous.
        complete:
            Whether the current round of the rendezvous is complete.
        deadline:
            The time at which the current round will be considered complete
            if it is still waiting for nodes to join.
        closed:
            Whether the rendezvous is closed.
        participants:
            The participants mapped to their ranks.
        wait_list:
            The nodes waiting to take part in the next round of the
            rendezvous.
        redundancy_list:
            The nodes that are redundant in the current round and can join the
            next rendezvous without triggering re-rendezvous.
        last_heartbeats:
            Each node's last heartbeat time.
    """

    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int]
    wait_list: Set[_NodeDesc]
    redundancy_list: Set[_NodeDesc]
    last_heartbeats: Dict[_NodeDesc, datetime]

    def __init__(self) -> None:
        # Round bookkeeping: first round, not complete, no deadline, open.
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False

        # Membership: every collection starts out empty.
        self.participants = dict()
        self.wait_list = set()
        self.redundancy_list = set()
        self.last_heartbeats = dict()
315
316
def _remove_participant_epilogue(
    state: _RendezvousState, settings: RendezvousSettings
) -> None:
    """Fix up the round state after a participant has been removed.

    Mutates ``state`` in place:

      - If the round is complete and no participant is left, the round is
        marked incomplete and the round counter is advanced.
      - If the round is not complete and fewer than ``settings.min_nodes``
        participants remain, the deadline is cleared.

    Args:
        state:
            The rendezvous state from which a participant was just removed.
        settings:
            The rendezvous settings; only ``min_nodes`` is read.
    """
    if state.complete:
        # If we do not have any participants left, move to the next round.
        if not state.participants:
            msg = "No participants left in the rendezvous, marking rendezvous as incomplete"
            logger.debug(msg)
            state.complete = False
            state.round += 1
        return

    num_participants = len(state.participants)
    if num_participants < settings.min_nodes:
        msg = (
            f"Number of participants {num_participants} is less than "
            f"min_nodes {settings.min_nodes}, clearing deadline in state"
        )
        logger.debug(msg)
        state.deadline = None
336
337
class _RendezvousStateHolder(ABC):
    """Abstract owner of the rendezvous state shared and synced with other nodes."""

    @property
    @abstractmethod
    def state(self) -> _RendezvousState:
        """Return the local copy of the state."""

    @abstractmethod
    def sync(self) -> Optional[bool]:
        """Read or write the latest state.

        Returns:
            A boolean telling whether the local state, if it was marked as
            dirty, was successfully synced with other nodes.
        """

    @abstractmethod
    def mark_dirty(self) -> None:
        """Flag the local state as dirty."""
358
359
class _BackendRendezvousStateHolder(_RendezvousStateHolder):
    """State holder that syncs the rendezvous state through a backend.

    Args:
        backend:
            The rendezvous backend to use.
        settings:
            The rendezvous settings.
        cache_duration:
            How long, in seconds, the last rendezvous state is served from
            the local cache before the backend is asked again.
    """

    _backend: RendezvousBackend
    _state: _RendezvousState
    _settings: RendezvousSettings
    _cache_duration: int
    _token: Token
    _dirty: bool
    _last_sync_time: float
    _dead_nodes: List[_NodeDesc]

    def __init__(
        self,
        backend: RendezvousBackend,
        settings: RendezvousSettings,
        cache_duration: int = 1,
    ) -> None:
        # Collaborators and configuration.
        self._backend = backend
        self._settings = settings
        self._cache_duration = cache_duration

        # Local view of the shared state and its sync bookkeeping.
        self._state = _RendezvousState()
        self._token = None
        self._dirty = False
        self._last_sync_time = -1
        self._dead_nodes = []

    def _record(self, message: str, node_state: NodeState = NodeState.RUNNING):
        """Record a rendezvous event named after the method that called us."""
        # get_method_name() is evaluated in this frame so that its default
        # depth resolves to the caller of _record().
        event_name = f"{self.__class__.__name__}.{get_method_name()}"
        construct_and_record_rdzv_event(
            name=event_name,
            run_id=self._settings.run_id,
            message=message,
            node_state=node_state,
        )

    @property
    def state(self) -> _RendezvousState:
        """See base class."""
        return self._state

    def sync(self) -> Optional[bool]:
        """See base class."""
        encoded: Optional[bytes] = None
        fresh_token = None
        has_set: Optional[bool]

        if self._dirty:
            # Local changes are pending: try to write them to the backend.
            has_set = False
            encoded = pickle.dumps(self._state)

            write_result = self._backend.set_state(encoded, self._token)
            if write_result is not None:
                encoded, fresh_token, has_set = write_result
        else:
            has_set = None

            if self._cache_duration > 0:
                # Avoid overloading the backend if we are asked to retrieve the
                # state repeatedly. Try to serve the cached state.
                cache_floor = max(time.monotonic() - self._cache_duration, 0)
                if self._last_sync_time >= cache_floor:
                    return None

            read_result = self._backend.get_state()
            if read_result is not None:
                encoded, fresh_token = read_result

        if encoded is None:
            # Nothing to decode: start over from an empty state.
            self._state = _RendezvousState()
        else:
            # NOTE(review): pickle.loads can execute arbitrary code; this
            # presumably relies on the backend contents being trusted --
            # confirm.
            try:
                self._state = pickle.loads(encoded)
            except pickle.PickleError as exc:
                raise RendezvousStateError(
                    "The rendezvous state is corrupt. See inner exception for details."
                ) from exc

        # ``_dead_nodes`` still holds the result of the previous _sanitize()
        # call here; it is reported only after a successful write.
        if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG):
            node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)

            msg = (
                f"As part of the sync operation the node(s) {node_list} have been removed from the "
                f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
            )
            self._record(message=msg)
            logger.debug(msg)

        self._token = fresh_token
        self._dirty = False
        self._last_sync_time = time.monotonic()

        self._sanitize()

        return has_set

    def _sanitize(self) -> None:
        """Drop nodes whose last heartbeat is too old from the local state."""
        state = self._state

        expire_time = datetime.utcnow() - (
            self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
        )

        # Collect the dead nodes first; the dictionary is modified below.
        self._dead_nodes = [
            node
            for node, heartbeat in state.last_heartbeats.items()
            if heartbeat < expire_time
        ]

        lost_participant = False

        for dead_node in self._dead_nodes:
            msg = f"Detected dead node '{dead_node}', removing it from the rendezvous"
            logger.debug(msg)
            del state.last_heartbeats[dead_node]

            if dead_node in state.participants:
                del state.participants[dead_node]
                lost_participant = True

            # A dead node may or may not be in either of these sets.
            state.wait_list.discard(dead_node)
            state.redundancy_list.discard(dead_node)

        if lost_participant:
            # Common epilogue shared with the _remove_from_participants()
            # function of _DistributedRendezvousOpExecutor.
            _remove_participant_epilogue(state, self._settings)

    def mark_dirty(self) -> None:
        """See base class.

        Once the local rendezvous state is dirty, the next sync call attempts
        to write the changes back to the backend. That attempt can fail when
        another node that started from the same state made its own changes and
        wrote them first.
        """
        self._dirty = True
523
524
525class _Action(Enum):
526    """Specifies the possible actions based on the state of the rendezvous."""
527
528    KEEP_ALIVE = 1
529    ADD_TO_PARTICIPANTS = 2
530    ADD_TO_WAIT_LIST = 3
531    ADD_TO_REDUNDANCY_LIST = 4
532    REMOVE_FROM_PARTICIPANTS = 5
533    REMOVE_FROM_WAIT_LIST = 6
534    REMOVE_FROM_REDUNDANCY_LIST = 7
535    MARK_RENDEZVOUS_COMPLETE = 8
536    MARK_RENDEZVOUS_CLOSED = 9
537    SYNC = 10
538    ERROR_CLOSED = 11
539    ERROR_TIMEOUT = 12
540    FINISH = 13
541
542
class _RendezvousContext:
    """Bundle of what a state handler needs to decide on the next action.

    Attributes:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state:
            The current state of the rendezvous.
        settings:
            The rendezvous settings.
    """

    node: _NodeDesc
    state: _RendezvousState
    settings: RendezvousSettings

    def __init__(
        self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
    ) -> None:
        # The three values are stored as given; nothing is copied.
        self.node, self.state, self.settings = node, state, settings
566
567
class _RendezvousOpExecutor(ABC):
    """Abstract runner of rendezvous operations."""

    @abstractmethod
    def run(
        self,
        state_handler: Callable[[_RendezvousContext, float], _Action],
        deadline: float,
        update_deadline: Optional[Callable[[timedelta], float]] = None,
    ) -> None:
        """Run one rendezvous operation.

        The operation runs inside a state machine and is expected to move the
        rendezvous from one state to another.

        Args:
            state_handler:
                A callable expected to return the next state transition
                action, given the current state of the rendezvous.
            deadline:
                The time, in seconds, at which the operation will be
                considered timed-out.
            update_deadline:
                Function that produces a new operation deadline if the
                current node may participate in the next rendezvous.
        """
594
595
class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
    """Execute rendezvous operations using a shared state.

    Args:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state_holder:
            The ``RendezvousStateHolder`` to use to sync the rendezvous state
            with other nodes.
        settings:
            The rendezvous settings.
    """

    _node: _NodeDesc
    # Assigned at the start of every iteration of run(); the mutation helpers
    # below operate on this object.
    _state: _RendezvousState
    _state_holder: _RendezvousStateHolder
    _settings: RendezvousSettings

    def __init__(
        self,
        node: _NodeDesc,
        state_holder: _RendezvousStateHolder,
        settings: RendezvousSettings,
    ) -> None:
        self._node = node
        self._state_holder = state_holder
        self._settings = settings

    def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
        """Record a rendezvous event named after the method that called us."""
        construct_and_record_rdzv_event(
            name=f"{self.__class__.__name__}.{get_method_name()}",
            run_id=self._settings.run_id,
            message=message,
            node_state=node_state,
            hostname=self._node.addr,
            pid=self._node.pid,
            local_id=self._node.local_id,
        )

    def run(
        self,
        state_handler: Callable[[_RendezvousContext, float], _Action],
        deadline: float,
        update_deadline: Optional[Callable[[timedelta], float]] = None,
    ) -> None:
        """See base class.

        Raises:
            RendezvousClosedError:
                ``state_handler`` returned ``_Action.ERROR_CLOSED``.
            RendezvousTimeoutError:
                ``state_handler`` returned ``_Action.ERROR_TIMEOUT``.
        """
        # Actions that change the local state, mapped to the method that
        # applies the change.
        mutations = {
            _Action.KEEP_ALIVE: self._keep_alive,
            _Action.ADD_TO_PARTICIPANTS: self._add_to_participants,
            _Action.ADD_TO_WAIT_LIST: self._add_to_wait_list,
            _Action.ADD_TO_REDUNDANCY_LIST: self._add_to_redundancy_list,
            _Action.REMOVE_FROM_PARTICIPANTS: self._remove_from_participants,
            _Action.REMOVE_FROM_WAIT_LIST: self._remove_from_wait_list,
            _Action.REMOVE_FROM_REDUNDANCY_LIST: self._remove_from_redundancy_list,
            _Action.MARK_RENDEZVOUS_COMPLETE: self._mark_rendezvous_complete,
            _Action.MARK_RENDEZVOUS_CLOSED: self._mark_rendezvous_closed,
        }

        action = None
        while action != _Action.FINISH:
            # Reads or writes the latest rendezvous state shared by all nodes in
            # the rendezvous. Note that our local changes might get overridden
            # by another node if that node synced its changes before us.
            has_set = self._state_holder.sync()
            if has_set is not None:
                if has_set:
                    msg = (
                        f"The node '{self._node}' has successfully synced its local changes with "
                        f"other nodes in the rendezvous '{self._settings.run_id}'."
                    )
                else:
                    msg = (
                        f"The node '{self._node}' has a stale state and failed to sync its local "
                        f"changes with other nodes in the rendezvous '{self._settings.run_id}'."
                    )

                self._record(message=msg)
                logger.debug(msg)

            self._state = self._state_holder.state

            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline)

            if action == _Action.FINISH:
                continue

            if action == _Action.ERROR_CLOSED:
                raise RendezvousClosedError

            if action == _Action.ERROR_TIMEOUT:
                raise RendezvousTimeoutError

            if action == _Action.SYNC:
                # Delay the execution by one second to avoid overloading the
                # backend if we are asked to poll for state changes.
                _delay(seconds=1)
                continue

            mutate = mutations.get(action)
            if mutate is not None:
                mutate()

            if action == _Action.REMOVE_FROM_REDUNDANCY_LIST and update_deadline:
                # update deadline since the node may participate in rendezvous process
                deadline = update_deadline(self._settings.timeout.join)

            # Attempt to sync our changes back to other nodes.
            self._state_holder.mark_dirty()

    def _keep_alive(self) -> None:
        """Stamp this node's heartbeat with the current UTC time."""
        msg = (
            f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
            f"'{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.last_heartbeats[self._node] = datetime.utcnow()

    def _add_to_participants(self) -> None:
        """Move this node into the participants of the current round."""
        msg = (
            f"The node '{self._node}' added itself to the participants of round "
            f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        state = self._state

        # The node may or may not have been waiting for this round.
        state.wait_list.discard(self._node)

        # The ranks of the participants will be set once the rendezvous is
        # complete.
        state.participants[self._node] = 0

        self._keep_alive()

        # Reaching min_nodes starts the last-call window.
        if len(state.participants) == self._settings.min_nodes:
            state.deadline = datetime.utcnow() + self._settings.timeout.last_call

        # Reaching max_nodes completes the round right away.
        if len(state.participants) == self._settings.max_nodes:
            self._mark_rendezvous_complete()

    def _add_to_wait_list(self) -> None:
        """Put this node on the wait list, dropping it from the redundancy list."""
        msg = (
            f"The node '{self._node}' added itself to the wait list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.redundancy_list.discard(self._node)
        self._state.wait_list.add(self._node)

        self._keep_alive()

    def _add_to_redundancy_list(self) -> None:
        """Put this node on the redundancy list."""
        msg = (
            f"The node '{self._node}' added itself to the redundancy list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.redundancy_list.add(self._node)

        self._keep_alive()

    def _remove_from_participants(self) -> None:
        """Remove this node and its heartbeat from the current round.

        Raises:
            KeyError:
                The node is not a participant or has no heartbeat entry.
        """
        msg = (
            f"The node '{self._node}' removed itself from the participants of round "
            f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        state = self._state

        del state.participants[self._node]

        del state.last_heartbeats[self._node]

        # Common epilogue shared with the _sanitize() function of
        # _BackendRendezvousStateHolder.
        _remove_participant_epilogue(state, self._settings)

    def _remove_from_wait_list(self) -> None:
        """Remove this node and its heartbeat from the wait list.

        Raises:
            KeyError:
                The node is not on the wait list or has no heartbeat entry.
        """
        msg = (
            f"The node '{self._node}' removed itself from the wait list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.wait_list.remove(self._node)

        del self._state.last_heartbeats[self._node]

    def _remove_from_redundancy_list(self) -> None:
        """Remove this node and its heartbeat from the redundancy list.

        Raises:
            KeyError:
                The node is not on the redundancy list or has no heartbeat
                entry.
        """
        msg = (
            f"The node '{self._node}' removed itself from the redundant list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.redundancy_list.remove(self._node)

        del self._state.last_heartbeats[self._node]

    def _mark_rendezvous_complete(self) -> None:
        """Mark the current round complete and assign participant ranks."""
        msg = (
            f"The node '{self._node}' marked round {self._state.round} of the rendezvous "
            f"'{self._settings.run_id}' as complete. Pending sync."
        )
        self._record(message=msg, node_state=NodeState.SUCCEEDED)
        logger.debug(msg)

        state = self._state

        state.complete = True
        state.deadline = None

        # Assign the ranks in the sort order of the node descriptors.
        for rank, node in enumerate(sorted(state.participants)):
            state.participants[node] = rank

    def _mark_rendezvous_closed(self) -> None:
        """Mark the rendezvous as closed."""
        msg = (
            f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. "
            "Pending sync."
        )
        self._record(message=msg, node_state=NodeState.SUCCEEDED)
        logger.debug(msg)

        self._state.closed = True
842
843
def _should_keep_alive(ctx: _RendezvousContext) -> bool:
    """Tell whether the node in ``ctx`` is due to send a keep-alive heartbeat.

    Returns:
        ``False`` when the node has no recorded heartbeat; otherwise whether
        its last heartbeat is at least ``keep_alive_interval`` old.
    """
    heartbeats = ctx.state.last_heartbeats
    if ctx.node not in heartbeats:
        return False

    due_threshold = datetime.utcnow() - ctx.settings.keep_alive_interval
    return heartbeats[ctx.node] <= due_threshold
852
853
class _RendezvousExitOp:
    """State handler implementing the rendezvous exit operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        """Pick the next action for a node that is leaving the rendezvous."""
        # A node that is not a participant has nothing to undo.
        if ctx.node not in ctx.state.participants:
            return _Action.FINISH

        # Still a participant: give up once the deadline has passed,
        # otherwise ask to be removed.
        if time.monotonic() > deadline:
            return _Action.ERROR_TIMEOUT
        return _Action.REMOVE_FROM_PARTICIPANTS
863
864
865class _RendezvousJoinOp:
866    """Represent a rendezvous join operation."""
867
868    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
869        state = ctx.state
870
871        # A closed rendezvous means that it no longer accepts new nodes.
872        if state.closed:
873            if ctx.node in state.redundancy_list:
874                msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous."
875                raise RendezvousGracefulExitError(msg)
876            return _Action.ERROR_CLOSED
877
878        if ctx.node in state.redundancy_list:
879            msg = f"The node {ctx.node} is in redunancy list"
880            logger.debug(msg)
881            # don't apply the timeout logic here, since we want to allow the node to rejoin
882            if len(state.participants) == ctx.settings.max_nodes:
883                if _should_keep_alive(ctx):
884                    return _Action.KEEP_ALIVE
885                else:
886                    return _Action.SYNC
887            else:
888                # transition to waiting state that will respect timeouts.
889                msg = f"The node {ctx.node} is removed from redunancy list"
890                logger.debug(msg)
891                return _Action.REMOVE_FROM_REDUNDANCY_LIST
892
893        is_participant = ctx.node in state.participants
894
895        # If we are part of the rendezvous and it is already complete there is
896        # no further action to take.
897        if state.complete and is_participant:
898            return _Action.FINISH
899
900        now = time.monotonic()
901        if now > deadline:
902            rollback_period = 5  # 5 seconds
903
904            # If we still have time to rollback (a short period on top of the
905            # operation deadline), try to remove ourself from the rendezvous.
906            # It is okay if we can't though as our keep-alive will eventually
907            # expire.
908            if now <= deadline + rollback_period:
909                # If we are part of the rendezvous, it means we couldn't find
910                # enough participants to complete it on time.
911                if is_participant:
912                    return _Action.REMOVE_FROM_PARTICIPANTS
913                # If we are in the wait list, it means we couldn't wait till the
914                # next round of the rendezvous.
915                if ctx.node in state.wait_list:
916                    return _Action.REMOVE_FROM_WAIT_LIST
917            return _Action.ERROR_TIMEOUT
918
919        if state.complete:
920            # If we are here, it means we are not part of the rendezvous. In
921            # case the rendezvous has capacity for additional participants add
922            # ourself to the wait list for the next round.
923            if len(state.participants) < ctx.settings.max_nodes:
924                if ctx.node not in state.wait_list:
925                    return _Action.ADD_TO_WAIT_LIST
926            elif len(state.participants) >= ctx.settings.max_nodes:
927                if (
928                    ctx.node not in state.redundancy_list
929                    and ctx.node not in state.wait_list
930                ):
931                    return _Action.ADD_TO_REDUNDANCY_LIST
932        elif is_participant:
933            # If the rendezvous has enough number of participants including us,
934            # check whether we have passed the rendezvous deadline. If yes,
935            # complete it.
936            if (
937                len(state.participants) >= ctx.settings.min_nodes
938                and len(state.participants) <= ctx.settings.max_nodes
939            ):
940                if cast(datetime, state.deadline) < datetime.utcnow():
941                    msg = (
942                        f"The node '{ctx.node}' marking the rendezvous complete, "
943                        f"quorum established within deadline"
944                    )
945                    logger.debug(msg)
946                    return _Action.MARK_RENDEZVOUS_COMPLETE
947                else:
948                    msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached"
949                    logger.debug(msg)
950            else:
951                msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants"
952                logger.debug(msg)
953        else:
954            # The rendezvous is not complete yet and we are not part of it. Try
955            # to join.
956            return _Action.ADD_TO_PARTICIPANTS
957
958        if _should_keep_alive(ctx):
959            return _Action.KEEP_ALIVE
960
961        # At this point either the rendezvous is not complete, but we are part
962        # of it, which means we have to wait for other participants to join; or
963        # the rendezvous is complete, but we are not part of it, which means we
964        # have to wait for the next round.
965        return _Action.SYNC
966
967
968class _RendezvousCloseOp:
969    """Represent a rendezvous close operation."""
970
971    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
972        if ctx.state.closed:
973            return _Action.FINISH
974        if time.monotonic() > deadline:
975            return _Action.ERROR_TIMEOUT
976        return _Action.MARK_RENDEZVOUS_CLOSED
977
978
class _RendezvousKeepAliveOp:
    """Represent a rendezvous keep-alive update operation.

    Calling an instance tells the caller whether a keep-alive heartbeat still
    has to be sent for the node of the given context.
    """

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        """Return the next action of the keep-alive operation.

        Args:
            ctx:
                The context passed on to ``_should_keep_alive``.
            deadline:
                The ``time.monotonic()`` value after which the operation is
                considered timed out.
        """
        # No heartbeat is due, so the operation is done.
        if not _should_keep_alive(ctx):
            return _Action.FINISH

        if time.monotonic() > deadline:
            return _Action.ERROR_TIMEOUT

        return _Action.KEEP_ALIVE
988
989
class DynamicRendezvousHandler(RendezvousHandler):
    """Represent a handler that sets up a rendezvous among a set of nodes."""

    # Static
    # Shared by all handler instances; ``from_backend`` asks it to generate
    # the node descriptor of each new handler.
    _node_desc_generator = _NodeDescGenerator()

    # Descriptor of the node this handler acts on behalf of.
    _this_node: _NodeDesc
    # Settings of the rendezvous (run id, node limits, timeouts, keep-alive).
    _settings: RendezvousSettings
    # Name reported by ``get_backend``.
    _backend_name: str
    # Store that gets wrapped with a per-round key prefix by ``_get_store``.
    _store: Store
    # Holder through which the rendezvous state is read and synced.
    _state_holder: _RendezvousStateHolder
    # Executor that runs the exit/join/close/keep-alive operations.
    _op_executor: _RendezvousOpExecutor
    # Held while sending a heartbeat and by the public state queries.
    _heartbeat_lock: threading.Lock
    # Timer driving the periodic heartbeats; ``None`` until they are started.
    _keep_alive_timer: Optional[_PeriodicTimer]
1004
1005    @classmethod
1006    def from_backend(
1007        cls,
1008        run_id: str,
1009        store: Store,
1010        backend: RendezvousBackend,
1011        min_nodes: int,
1012        max_nodes: int,
1013        local_addr: Optional[str] = None,
1014        timeout: Optional[RendezvousTimeout] = None,
1015    ):
1016        """Create a new :py:class:`DynamicRendezvousHandler`.
1017
1018        Args:
1019            run_id:
1020                The run id of the rendezvous.
1021            store:
1022                The C10d store to return as part of the rendezvous.
1023            backend:
1024                The backend to use to hold the rendezvous state.
1025            min_nodes:
1026                The minimum number of nodes to admit to the rendezvous.
1027            max_nodes:
1028                The maximum number of nodes to admit to the rendezvous.
1029            local_addr:
1030                The local node address.
1031            timeout:
1032                The timeout configuration of the rendezvous.
1033        """
1034        # We associate each handler instance with a unique node descriptor.
1035        node = cls._node_desc_generator.generate(local_addr)
1036
1037        settings = RendezvousSettings(
1038            run_id,
1039            min_nodes,
1040            max_nodes,
1041            timeout or RendezvousTimeout(),
1042            keep_alive_interval=timedelta(seconds=5),
1043            keep_alive_max_attempt=3,
1044        )
1045
1046        state_holder = _BackendRendezvousStateHolder(backend, settings)
1047
1048        return cls(node, settings, backend.name, store, state_holder)
1049
1050    def __init__(
1051        self,
1052        node: _NodeDesc,
1053        settings: RendezvousSettings,
1054        backend_name: str,
1055        store: Store,
1056        state_holder: _RendezvousStateHolder,
1057    ) -> None:
1058        if not settings.run_id:
1059            raise ValueError("The run id must be a non-empty string.")
1060
1061        if settings.min_nodes < 1:
1062            raise ValueError(
1063                f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero."
1064            )
1065
1066        if settings.max_nodes < settings.min_nodes:
1067            raise ValueError(
1068                f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal "
1069                f"to the minimum number of nodes ({settings.min_nodes})."
1070            )
1071
1072        self._this_node = node
1073
1074        self._settings = settings
1075
1076        self._backend_name = backend_name
1077
1078        self._store = store
1079
1080        self._state_holder = state_holder
1081
1082        self._op_executor = _DistributedRendezvousOpExecutor(
1083            self._this_node, self._state_holder, self._settings
1084        )
1085
1086        self._heartbeat_lock = threading.Lock()
1087
1088        self._keep_alive_timer = None
1089
1090        # Cached shared store server reference
1091        self._shared_tcp_store_server: Optional[dist.Store] = None
1092
1093        self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None
1094
1095    def _record(
1096        self,
1097        message: str,
1098        node_state: NodeState = NodeState.RUNNING,
1099        rank: Optional[int] = None,
1100    ) -> None:
1101        construct_and_record_rdzv_event(
1102            name=f"{self.__class__.__name__}.{get_method_name()}",
1103            run_id=self._settings.run_id,
1104            message=message,
1105            node_state=node_state,
1106            hostname=self._this_node.addr,
1107            pid=self._this_node.pid,
1108            local_id=self._this_node.local_id,
1109            rank=rank,
1110        )
1111
1112    def _create_tcp_store_server(self, bootstrap_store_info) -> dist.TCPStore:
1113        return dist.TCPStore(
1114            bootstrap_store_info.master_addr,
1115            bootstrap_store_info.master_port,
1116            is_master=True,
1117            multi_tenant=True,
1118        )
1119
1120    @property
1121    def settings(self) -> RendezvousSettings:
1122        """Get the settings of the rendezvous."""
1123        return self._settings
1124
1125    def get_backend(self) -> str:
1126        """See base class."""
1127        return self._backend_name
1128
1129    @property
1130    def use_agent_store(self) -> bool:
1131        """See base class."""
1132        return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1"
1133
1134    def next_rendezvous(self) -> RendezvousInfo:
1135        """See base class."""
1136        msg = (
1137            f"The node '{self._this_node}' attempts to join the next round of the rendezvous "
1138            f"'{self._settings.run_id}'."
1139        )
1140        self._record(message=msg)
1141        logger.info(msg)
1142
1143        try:
1144            self._stop_heartbeats()
1145
1146            # Delay the execution for a small random amount of time if this is our
1147            # first run. This will slightly skew the rendezvous attempts across the
1148            # nodes and reduce the load on the backend.
1149            if self._state_holder.state.round == 0:
1150                _delay(seconds=(0, 0.3))
1151
1152            exit_op = _RendezvousExitOp()
1153            join_op = _RendezvousJoinOp()
1154
1155            deadline = self._get_deadline(self._settings.timeout.join)
1156            self._op_executor.run(exit_op, deadline)
1157            self._op_executor.run(join_op, deadline, self._get_deadline)
1158
1159            self._start_heartbeats()
1160
1161            rank, world_size = self._get_world()
1162            store = self._get_store()
1163
1164        except Exception as e:
1165            self._record(
1166                message=f"{type(e).__name__}: {str(e)}",
1167                node_state=NodeState.FAILED,
1168            )
1169            raise
1170
1171        msg = (
1172            f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of "
1173            f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size "
1174            f"{world_size}."
1175        )
1176        self._record(message=msg, rank=rank)
1177        logger.info(msg)
1178
1179        # opt-out option of TCP store sharing
1180        if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1":
1181            bootstrap_store_info = RendezvousStoreInfo.build(
1182                rank, store, local_addr=self._this_node.addr
1183            )
1184            return RendezvousInfo(
1185                store,
1186                rank,
1187                world_size,
1188                bootstrap_store_info,
1189            )
1190
1191        if self._bootstrap_store_info is None:
1192            if isinstance(self._store, dist.TCPStore):
1193                addr = self._store.host
1194                port = self._store.port
1195                self._bootstrap_store_info = RendezvousStoreInfo(
1196                    master_addr=addr, master_port=port
1197                )
1198                if rank == 0:
1199                    self._shared_tcp_store_server = self._store
1200            else:
1201                # If the store is not type of TCPStore start TCPStore server, which requries
1202                # bootstrapping info across ranks
1203                self._bootstrap_store_info = RendezvousStoreInfo.build(
1204                    rank, store, local_addr=self._this_node.addr
1205                )
1206                if rank == 0:
1207                    self._shared_tcp_store_server = self._create_tcp_store_server(
1208                        self._bootstrap_store_info
1209                    )
1210
1211        assert self._bootstrap_store_info is not None
1212        if rank == 0:
1213            assert self._shared_tcp_store_server is not None
1214
1215        return RendezvousInfo(
1216            store,
1217            rank,
1218            world_size,
1219            self._bootstrap_store_info,  # type: ignore[assignment]
1220        )
1221
1222    def is_closed(self) -> bool:
1223        """See base class."""
1224        try:
1225            with self._heartbeat_lock:
1226                self._state_holder.sync()
1227
1228                return self._state_holder.state.closed
1229
1230        except Exception as e:
1231            self._record(
1232                message=f"{type(e).__name__}: {str(e)}",
1233                node_state=NodeState.FAILED,
1234            )
1235            raise
1236
1237    def set_closed(self) -> None:
1238        """See base class."""
1239        try:
1240            with self._heartbeat_lock:
1241                self._close()
1242        except Exception as e:
1243            self._record(
1244                message=f"{type(e).__name__}: {str(e)}",
1245                node_state=NodeState.FAILED,
1246            )
1247            raise
1248
1249    def num_nodes_waiting(self) -> int:
1250        """See base class."""
1251        try:
1252            with self._heartbeat_lock:
1253                self._state_holder.sync()
1254
1255                return len(self._state_holder.state.wait_list)
1256
1257        except Exception as e:
1258            self._record(
1259                message=f"{type(e).__name__}: {str(e)}",
1260                node_state=NodeState.FAILED,
1261            )
1262            raise
1263
1264    def get_run_id(self) -> str:
1265        """See base class."""
1266        return self._settings.run_id
1267
1268    def shutdown(self) -> bool:
1269        """See base class."""
1270        self._stop_heartbeats()
1271
1272        try:
1273            self._close()
1274
1275            return True
1276        except RendezvousError as ex:
1277            msg = (
1278                f"The node '{self._this_node}' has failed to shutdown the rendezvous "
1279                f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}."
1280            )
1281            self._record(message=msg, node_state=NodeState.FAILED)
1282            logger.warning(msg)
1283
1284            return False
1285        except Exception as e:
1286            self._record(
1287                message=f"{type(e).__name__}: {str(e)}",
1288                node_state=NodeState.FAILED,
1289            )
1290            raise
1291
1292    def _close(self) -> None:
1293        op = _RendezvousCloseOp()
1294
1295        deadline = self._get_deadline(self._settings.timeout.close)
1296
1297        self._op_executor.run(op, deadline)
1298
1299        msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'."
1300        self._record(message=msg, node_state=NodeState.SUCCEEDED)
1301        logger.info(msg)
1302
1303    @staticmethod
1304    def _keep_alive_weak(weak_self) -> None:
1305        self = weak_self()
1306        if self is not None:
1307            self._keep_alive()
1308
1309    def _keep_alive(self) -> None:
1310        self._heartbeat_lock.acquire()
1311
1312        op = _RendezvousKeepAliveOp()
1313
1314        deadline = self._get_deadline(self._settings.timeout.heartbeat)
1315
1316        try:
1317            self._op_executor.run(op, deadline)
1318
1319            msg = (
1320                f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
1321                f"'{self._settings.run_id}'."
1322            )
1323            self._record(message=msg)
1324            logger.debug(msg)
1325        except RendezvousError as ex:
1326            msg = (
1327                f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
1328                f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
1329            )
1330            self._record(message=msg, node_state=NodeState.FAILED)
1331            logger.warning(msg)
1332        finally:
1333            self._heartbeat_lock.release()
1334
1335    def _start_heartbeats(self) -> None:
1336        self._keep_alive_timer = _PeriodicTimer(
1337            self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
1338        )
1339
1340        self._keep_alive_timer.set_name(
1341            f"RendezvousKeepAliveTimer_{self._this_node.local_id}"
1342        )
1343
1344        self._keep_alive_timer.start()
1345
1346    def _stop_heartbeats(self) -> None:
1347        if self._keep_alive_timer is None:
1348            return
1349
1350        self._keep_alive_timer.cancel()
1351
1352    def _get_world(self) -> Tuple[int, int]:
1353        state = self._state_holder.state
1354
1355        return state.participants[self._this_node], len(state.participants)
1356
1357    def _wrap_store(self, store: Store) -> Store:
1358        key_prefix = (
1359            f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}"
1360        )
1361
1362        return dist.PrefixStore(key_prefix, store)
1363
1364    def _get_store(self) -> Store:
1365        return self._wrap_store(self._store)
1366
1367    def _get_deadline(self, timeout: timedelta) -> float:
1368        return time.monotonic() + timeout.total_seconds()
1369
1370
1371def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]:
1372    timeout = params.get_as_int(key + "_timeout")
1373    if timeout is None:
1374        return None
1375    return timedelta(seconds=timeout)
1376
1377
def create_handler(
    store: Store, backend: RendezvousBackend, params: RendezvousParameters
) -> DynamicRendezvousHandler:
    """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters.

    Any exception raised while creating the handler is recorded as a failed
    rendezvous event and re-raised.

    Args:
        store:
            The C10d store to return as part of the rendezvous.
        backend:
            The backend to use to hold the rendezvous state.
        params:
            The rendezvous parameters. Besides the run id, node limits, and
            local address, the timeouts listed below are read from them.

    +-------------------+------------------------------------------------------+
    | Parameter         | Description                                          |
    +===================+======================================================+
    | join_timeout      | The total time, in seconds, within which the         |
    |                   | rendezvous is expected to complete. Defaults to 600  |
    |                   | seconds.                                             |
    +-------------------+------------------------------------------------------+
    | last_call_timeout | An additional wait amount, in seconds, before        |
    |                   | completing the rendezvous once the minimum number of |
    |                   | nodes has been reached. Defaults to 30 seconds.      |
    +-------------------+------------------------------------------------------+
    | close_timeout     | The time, in seconds, within which the rendezvous is |
    |                   | expected to close after a call to                    |
    |                   | :py:meth:`RendezvousHandler.set_closed` or           |
    |                   | :py:meth:`RendezvousHandler.shutdown`. Defaults to   |
    |                   | 30 seconds.                                          |
    +-------------------+------------------------------------------------------+
    """
    try:
        # Timeouts that are not specified come back as ``None``.
        join_timeout, last_call_timeout, close_timeout = (
            _get_timeout(params, key) for key in ("join", "last_call", "close")
        )
        rdzv_timeout = RendezvousTimeout(join_timeout, last_call_timeout, close_timeout)

        return DynamicRendezvousHandler.from_backend(
            run_id=params.run_id,
            store=store,
            backend=backend,
            min_nodes=params.min_nodes,
            max_nodes=params.max_nodes,
            local_addr=params.local_addr,
            timeout=rdzv_timeout,
        )
    except Exception as exc:
        construct_and_record_rdzv_event(
            message=f"{type(exc).__name__}: {exc!s}",
            run_id=params.run_id,
            node_state=NodeState.FAILED,
        )
        raise
1430