xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/rendezvous/etcd_rendezvous.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import json
11import logging
12import sys
13import threading
14import time
15from typing import Optional
16
17import etcd  # type: ignore[import]
18
19from torch.distributed.elastic.rendezvous import (
20    RendezvousClosedError,
21    RendezvousError,
22    RendezvousHandler,
23    RendezvousInfo,
24    RendezvousParameters,
25    RendezvousStoreInfo,
26    RendezvousTimeoutError,
27)
28
29from .etcd_store import cas_delay, EtcdStore
30from .utils import parse_rendezvous_endpoint
31
32
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    # Internal retry signals raised by the EtcdRendezvous state machine.
    "EtcdRendezvousRetryableFailure",
    "EtcdRendezvousRetryImmediately",
    # RendezvousHandler adapter and the etcd-backed rendezvous implementation.
    "EtcdRendezvousHandler",
    "EtcdRendezvous",
    "create_rdzv_handler",
]
40
# Module logger. It logs at INFO level to stderr with its own
# "<LEVEL> <time> <message>" format and does not propagate to ancestor
# loggers, so rendezvous progress is printed even when the application has
# not configured logging itself.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False

_log_handler = logging.StreamHandler(sys.stderr)
_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s")
_log_handler.setFormatter(_log_fmt)
logger.addHandler(_log_handler)
49
50
# Retryable failure exception means that we were too late to make
52# a desired state transition (e.g. because of a race condition),
53# and should now restart from the beginning.
54# A small delay is recommended to avoid spamming Etcd.
class EtcdRendezvousRetryableFailure(Exception):
    """Signal that a rendezvous attempt must restart after a short delay.

    Raised when this worker was too late to make a desired state transition
    (for example it lost a race with another worker). The caller restarts the
    attempt from the beginning, pausing briefly first so etcd is not spammed.
    """
57
58
59# Similar to retryable failure, but the new state we observed suggests we
60# can re-try immediately, i.e. without a need for "safety delay".
class EtcdRendezvousRetryImmediately(Exception):
    """Signal that a rendezvous attempt may restart right away.

    Like :py:class:`EtcdRendezvousRetryableFailure`, but the newly observed
    state indicates that no safety delay is needed before retrying.
    """
63
64
65# Default timeout for the rendezvous.
_DEFAULT_TIMEOUT: int = 600  # seconds, i.e. 10 minutes

# Extra wait ("last call") after the minimum number of nodes has joined,
# applied when the rendezvous is elastic (min != max).
_DEFAULT_LAST_CALL_TIMEOUT: int = 30  # seconds

# TTLs (seconds) used internally by EtcdRendezvous for the ephemeral
# active-version node while it is in the given state:
CONST_ETCD_SETUP_TTL: int = 5  # 'setup', while a new rendezvous is created
CONST_ETCD_FROZEN_TTL: int = 10  # 'frozen', while members confirm
CONST_ETCD_JOINABLE_EPHEMERAL_TTL: int = 10  # 'joinable' once min is reached

# TTL of each worker's keep-alive key:
CONST_WORKER_KEEPALIVE_TTL: int = 10

# TTL of the run_id-specific directory that holds all rendezvous state for one
# job instance. It exists only to clean up data from old runs when the etcd
# server is persistent and has no effect on correctness. It should still be
# larger than any timeout a worker process is expected to survive.
CONST_RUNID_SUBROOT_TTL: int = 7200  # 2 hours
86
87
class EtcdRendezvousHandler(RendezvousHandler):
    """
    Implements a
    :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` interface
    backed by
    :py:class:`torch.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`.
    ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to
    use and to pass implementation specific configurations to the rendezvous
    module. The basic etcd rendezvous configuration URL looks like the following
    ::

     etcd://<etcd_address>:<port>/<job_id>?min_workers=<min_workers>&max_workers=<max_workers>  # noqa: W605

     -- example --

     etcd://localhost:2379/1234?min_workers=1&max_workers=3

    The URL above is interpreted as follows:

    1. Use the rendezvous handler that is registered with the ``etcd``
       scheme
    2. The ``etcd`` endpoint to use is ``localhost:2379``
    3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to
       share a common etcd server for multiple jobs so long as the
       ``job_ids`` are guaranteed to be unique). Note that the job id can be
       any string (e.g. does not need to be a number) as long as it is
       unique.
    4. ``min_workers=1`` and ``max_workers=3`` specify a range for
       membership size - Torch Distributed Elastic starts running the job as
       long as the cluster size is greater than or equal to ``min_workers``
       and admits up to ``max_workers`` into the cluster.

    Below is the full list of parameters that can be passed to etcd
    rendezvous:

    +--------------------------------------------+--------------------------+
    | Parameter                                  | Description              |
    +============================================+==========================+
    | min_workers                                | minimum number of        |
    |                                            | workers for the          |
    |                                            | rendezvous to be valid   |
    +--------------------------------------------+--------------------------+
    | max_workers                                | maximum number of        |
    |                                            | workers to admit         |
    +--------------------------------------------+--------------------------+
    | timeout                                    | total timeout within     |
    |                                            | which next_rendezvous is |
    |                                            | expected to succeed      |
    |                                            | (default 600s)           |
    +--------------------------------------------+--------------------------+
    | last_call_timeout                          | additional wait amount   |
    |                                            | ("last call") after min  |
    |                                            | number of workers has    |
    |                                            | been reached (defaults   |
    |                                            | to 30s)                  |
    +--------------------------------------------+--------------------------+
    | etcd_prefix                                | path prefix (from etcd   |
    |                                            | root), inside which all  |
    |                                            | etcd nodes will be       |
    |                                            | created (defaults to     |
    |                                            | ``/torchelastic/p2p``)   |
    +--------------------------------------------+--------------------------+
    """

    def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]):
        """
        Args:
            rdzv_impl: the implementation of the rendezvous
            local_addr: the local address of the current node; forwarded to
                ``RendezvousStoreInfo.build`` and may be ``None``
        """
        self._rdzv_impl = rdzv_impl
        self._local_addr = local_addr

    def __del__(self):
        # TODO: look into using weakref here instead.
        # Drop our reference to the implementation so it can be collected
        # (its own __del__ signals the lease-renewal stop events).
        # The attribute is absent when __init__ never completed (e.g. it was
        # called with the wrong arguments). A finalizer must not raise in that
        # case, so tolerate it.
        try:
            del self._rdzv_impl
        except AttributeError:
            pass

    def get_backend(self) -> str:
        """Return the name of this rendezvous backend (``"etcd"``)."""
        return "etcd"

    def next_rendezvous(self):
        """
        Block until the next rendezvous round completes and describe it.

        Returns:
            ``RendezvousInfo`` holding the store produced by
            ``setup_kv_store`` for the completed rendezvous version (an
            ``EtcdStore``, per the log message below), this node's rank, the
            world size and the bootstrap store info.
        """
        rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier()

        logger.info("Creating EtcdStore as the c10d::Store implementation")
        store = self._rdzv_impl.setup_kv_store(rdzv_version)

        bootstrap_store_info = RendezvousStoreInfo.build(
            rank, store, local_addr=self._local_addr
        )
        return RendezvousInfo(store, rank, world_size, bootstrap_store_info)

    def is_closed(self):
        """Return True iff the rendezvous state in etcd has status 'closed'."""
        try:
            _, state = self._rdzv_impl.get_rdzv_state()
            return state["status"] == "closed"
        except etcd.EtcdKeyNotFound:
            # No rendezvous state, so it cannot be closed.
            return False

    def set_closed(self):
        """Mark the rendezvous as closed (delegates to the implementation)."""
        self._rdzv_impl.set_closed()

    def num_nodes_waiting(self):
        """
        Return how many workers are waiting to join the next rendezvous round.

        The count is only tracked while the current rendezvous is 'final'. In
        any other state, or when no state exists yet, 0 is returned.
        """
        try:
            _, state = self._rdzv_impl.get_rdzv_state()
            if state["status"] == "final":
                return state["num_workers_waiting"]
        except etcd.EtcdKeyNotFound:
            pass
        return 0

    def get_run_id(self) -> str:
        """Return the run id (job instance id) of the underlying rendezvous."""
        return self._rdzv_impl._run_id

    def shutdown(self) -> bool:
        """
        Close the rendezvous as a best-effort operation.

        Returns:
            True if the rendezvous was closed, False if closing failed (the
            error is logged as a warning).
        """
        try:
            self.set_closed()
            return True
        # NOTE(review): BaseException also swallows KeyboardInterrupt and
        # SystemExit raised while closing; kept as-is for compatibility.
        except BaseException as e:
            logger.warning("Shutdown failed. Error occurred: %s", str(e))
            return False
210
211
212# TODO: we should probably handle a few additional errors,
213# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are
214# only relevant for multi-node Etcd ensemble. A simple retry would work,
215# but is verbose to add everywhere. Consider wrapping the client calls
216# into auto-retry for these errors?
217#
218class EtcdRendezvous:
219    """A rendezvous implementation that uses `etcd <https://etcd.io/>`__ as the backend store."""
220
221    def __init__(
222        self,
223        client,
224        prefix,
225        run_id,
226        num_min_workers,
227        num_max_workers,
228        timeout,
229        last_call_timeout,
230    ):
231        self.client = client
232        logger.info("Etcd machines: %s", self.client.machines)
233
234        self._prefix = prefix
235        self._run_id = run_id
236        self._num_min_workers = num_min_workers
237        self._num_max_workers = num_max_workers
238        self._timeout = timeout
239        self._last_call_timeout = last_call_timeout
240
241        # For cleaning up TTL refresher threads (for ephemeral keys)
242        self._lease_run_id_stop = None
243        self._lease_this_rank_stop = None
244
245        if not self._prefix.endswith("/"):
246            self._prefix += "/"
247
248        # Setup a permanent prefix dir, if didn't exist
249        if self._prefix != "/":
250            self.create_path_if_not_exists(self._prefix)
251
252        # Lease a "sub-root" node specific to this job instance (run_id)
253        self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL)
254        self._lease_run_id_stop = self.setup_lease_renewal(
255            self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL
256        )
257
258        # Subdir for all rendezvous work
259        self.create_path_if_not_exists(self.get_path("/rdzv"))
260
261        # Create a rendezvous version counter, if doesn't exist
262        try:
263            self.client.write(
264                key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False
265            )
266        except etcd.EtcdAlreadyExist:
267            pass
268
269    def __del__(self):
270        # TODO: look into using weakref here instead.
271        if self._lease_run_id_stop is not None:
272            self._lease_run_id_stop.set()
273
274        if self._lease_this_rank_stop is not None:
275            self._lease_this_rank_stop.set()
276
277    def rendezvous_barrier(self):
278        """
279        Main entry point for next rendezvous.
280
281        This method is blocking until rendezvous succeeds or a timeout occurs.
282
283        Returns:
284             ``(rdzv_version, rank, world_size)``
285
286        Raises:
287            RendezvousTimeoutError - timeout waiting for rendezvous
288            RendezvousClosedError - rendezvous is or was closed while waiting
289            RendezvousError - other persistent errors that
290             render the rendezvous non-retryable
291        """
292        self._rendezvous_deadline = time.time() + self._timeout
293        while True:
294            if time.time() > self._rendezvous_deadline:
295                raise RendezvousTimeoutError
296
297            logger.info("Attempting to join next rendezvous")
298            try:
299                # Dis-own our lease in the previous rendezvous, if exists
300                if self._lease_this_rank_stop is not None:
301                    self._lease_this_rank_stop.set()
302
303                return self.init_phase()
304
305            except EtcdRendezvousRetryImmediately:
306                # The type of failure suggests we can retry without delay
307                pass
308
309            except EtcdRendezvousRetryableFailure:
310                # In case of retryable failure, wait a small delay
311                # to avoid spamming etcd
312                time.sleep(1)
313
314            except RendezvousTimeoutError:
315                logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler")
316                raise
317
318            except RendezvousClosedError:
319                logger.info(
320                    "Rendezvous for run_id=%s was observed to be closed", self._run_id
321                )
322                raise
323
324            except RendezvousError:
325                raise
326
327            except Exception as e:
328                # In case of a general exception, wait a small delay
329                # to avoid spamming etcd
330                # FIXME: there are a few things that fall under this like
331                # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly.
332                logger.info("Rendezvous attempt failed, will retry. Reason: %s", e)
333                time.sleep(1)
334
335    def init_phase(self):
336        """
337        Initially, the rendezvous state is expected to be one of:
338
339        1. empty (non-existent) - in this case we try to create a new one.
340        2. joinable - we try to join it.
341        3. final - we announce ourselves as waiting, and go into monitoring mode
342
343        Any other state is considered transitional, and will be retried after
344        a short delay.
345
346        Returns:
347            ``(rdzv_version, rank, world_size)``
348
349        Raises:
350            RendezvousClosedError - current rendezvous was/is closed
351            EtcdRendezvousRetryableFailure - observed some intermediate
352             state, which is best handled by retrying later
353        """
354        try:
355            active_version = self.try_create_rendezvous()
356            state = json.loads(active_version.value)
357            logger.info("New rendezvous state created: %s", state)
358        except etcd.EtcdAlreadyExist:
359            active_version, state = self.get_rdzv_state()
360            # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound),
361            # but this is ok for us - just means we'll restart from beginning.
362            logger.info("Observed existing rendezvous state: %s", state)
363
364        if state["status"] == "closed":
365            raise RendezvousClosedError
366
367        if state["status"] == "joinable":
368            return self.join_phase(state["version"])
369
370        if state["status"] == "final":
371            self.handle_existing_rendezvous(state["version"])
372            raise EtcdRendezvousRetryImmediately
373
374        self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1)
375        raise EtcdRendezvousRetryableFailure
376
377    def join_phase(self, expected_version):
378        """
379        We observed a rendezvous state in 'joinable' state, and attempt to join this
380        particular version, and then wait for all other peers to join.
381        """
382        # Failure to join will propagate an exception, causing a re-entry.
383        active_version, this_rank = self.join_rendezvous(expected_version)
384        state = json.loads(active_version.value)
385        logger.info(
386            "Joined rendezvous version %s as rank %s. Full state: %s",
387            state["version"],
388            this_rank,
389            state,
390        )
391
392        # If this worker was first to reach num_min_workers requirement,
393        # and rendezvous is still joinable (therefore it is elastic),
394        # then this worker will be responsible for waiting out the "last call"
395        # timeout and closing (i.e. transitioning to 'frozen') the rendezvous
396        # afterwards.
397        # As a safety against a potential failure of this worker (during the
398        # last call timeout), the rendezvous state is made ephemeral
399        # when min_num_workers is reached.
400
401        if this_rank == self._num_min_workers - 1 and state["status"] == "joinable":
402            logger.info("Rank %s is responsible for join last call.", this_rank)
403            last_call_deadline = time.time() + self._last_call_timeout
404            self.handle_join_last_call(expected_version, last_call_deadline)
405            logger.info("Rank %s finished join last call.", this_rank)
406
407        # Wait for rendezvous state to be frozen, which means a fixed set of peers
408        logger.info("Waiting for remaining peers.")
409        active_version = self.wait_for_peers(expected_version)
410        state = json.loads(active_version.value)
411
412        assert (
413            state["version"] == expected_version
414        ), "Logic error: failed to observe version mismatch"
415
416        return self.confirm_phase(expected_version, this_rank)
417
418    def confirm_phase(self, expected_version, this_rank):
419        """
420        Once the rendezvous state transitions from 'joinable' to 'frozen',
421        we have every participant confirm their membership and setup per-member
422        keep-alive TTL keys, and then wait for all other participants to confirm,
423        which would then successfully conclude this rendezvous.
424        """
425        logger.info("All peers arrived. Confirming membership.")
426        self.confirm_membership(expected_version, this_rank)
427
428        logger.info("Waiting for confirmations from all peers.")
429        active_version = self.wait_for_final(expected_version)
430        state = json.loads(active_version.value)
431
432        logger.info(
433            "Rendezvous version %s is complete. Final state: %s",
434            state["version"],
435            state,
436        )
437
438        # Rendezvous version number; our rank in it; world size
439        return state["version"], this_rank, len(state["participants"])
440
441    def handle_existing_rendezvous(self, expected_version):
442        """
443        Handle the case when there's an existing (state 'final) rendezvous already
444        in place, and we have to announce ourselves waiting, and wait until
445        the next rendezvous opportunity.
446        """
447        # If state is 'final' -> increment num_workers_waiting
448        # Then, observe state changes:
449        #   1. if it's no longer final -> bail out and re-try
450        #   2. if keep alives are missing, destroy it and bail out.
451        active_state = self.announce_self_waiting(expected_version)
452        logger.info(
453            "Added self to waiting list. Rendezvous full state: %s", active_state.value
454        )
455
456        self.wait_for_rendezvous_to_free(expected_version)
457        logger.info(
458            "Previously existing rendezvous state changed. Will re-try joining."
459        )
460
461    def try_create_rendezvous(self):
462        """
463        Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists).
464
465        Raises:
466             RendezvousError - on unexpected state
467        """
468        # Initially active_version is ephemeral - this is to handle the
469        # possibility that might fail to complete the setup transaction,
470        # i.e. the transition "setup" -> "joinable".
471        active_version = self.client.write(
472            key=self.get_path("/rdzv/active_version"),
473            value=json.dumps({"status": "setup"}),
474            prevExist=False,
475            ttl=CONST_ETCD_SETUP_TTL,
476        )
477
478        try:
479            version_counter = self.client.get(self.get_path("/rdzv/version_counter"))
480            version_counter.value = str(int(version_counter.value) + 1)
481            self.client.update(version_counter)
482        except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e:
483            raise RendezvousError(
484                "Unexpected state of EtcdRendezvousHandler, worker needs to die."
485            ) from e
486
487        # Any failure below results in declaring a retryable rendezvous failure.
488        # The ephemeral /rdzv/active_version will expire and someone can then
489        # re-try the setup process.
490
491        # Create directory node for participant data
492        self.client.write(
493            key=self.get_path(f"/rdzv/v_{version_counter.value}"),
494            value=None,
495            dir=True,
496            prevExist=False,
497        )
498
499        # Publish rendezvous version and signal it is ready-to-be-joined.
500        # If rendezvous was set closed just before this, a retry will happen,
501        # where the closed condition will be handled.
502        return self.client.test_and_set(
503            key=self.get_path("/rdzv/active_version"),
504            value=json.dumps(
505                {
506                    "status": "joinable",
507                    "version": version_counter.value,
508                    "participants": [],
509                }
510            ),
511            prev_value=active_version.value,
512        )
513
514    def join_rendezvous(self, expected_version):
515        """Helper method for the join phase."""
516        # Use compare-and-swap to add self to rendezvous state:
517        while True:
518            cas_delay()
519            active_version, state = self.get_rdzv_state()
520
521            if state["status"] != "joinable":
522                raise EtcdRendezvousRetryableFailure(
523                    "Rendezvous state became non-joinable before we could join. "
524                    "Must join next one."
525                )
526
527            if state["version"] != expected_version:
528                raise EtcdRendezvousRetryImmediately(
529                    "Rendezvous version changed. Must try join the new one."
530                )
531
532            assert (
533                len(state["participants"]) < self._num_max_workers
534            ), "Logic error: joinable rendezvous should always have space left"
535
536            this_rank = len(state["participants"])
537            state["participants"].append(this_rank)
538
539            # When reaching min workers, or changing state to frozen, we'll set
540            # the active_version node to be ephemeral.
541            set_ttl: Optional[int] = None
542            if len(state["participants"]) == self._num_max_workers:
543                state["status"] = "frozen"
544                state["keep_alives"] = []
545                set_ttl = CONST_ETCD_FROZEN_TTL
546            elif len(state["participants"]) >= self._num_min_workers:
547                set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL
548
549            try:
550                # Compare-and-swap.
551                active_version = self.client.test_and_set(
552                    key=self.get_path("/rdzv/active_version"),
553                    value=json.dumps(state),
554                    prev_value=active_version.value,
555                    ttl=set_ttl,
556                )
557                # We succeeded joining.
558                return active_version, this_rank
559
560            except etcd.EtcdCompareFailed:
561                logger.info("Join rendezvous CAS unsuccessful, retrying")
562
563    def wait_for_peers(self, expected_version):
564        """Helper method for the join phase."""
565        active_version, state = self.get_rdzv_state()
566        while True:
567            if state["status"] == "frozen" and state["version"] == expected_version:
568                # Success, all peers arrived.
569                return active_version
570
571            elif state["status"] == "joinable" and state["version"] == expected_version:
572                # Continue waiting for any interesting events.
573                active_version, state = self.try_wait_for_state_change(
574                    etcd_index=active_version.etcd_index + 1
575                )
576
577            else:
578                # No valid transition possible at this point
579                raise EtcdRendezvousRetryableFailure(
580                    "Rendezvous state transition no longer possible. Must re-enter."
581                )
582
583    def confirm_membership(self, expected_version, this_rank):
584        """Helper method for the confirm phase."""
585        # Compare-and-swap loop
586        while True:
587            cas_delay()
588            active_version, state = self.get_rdzv_state()
589
590            if state["status"] != "frozen":
591                raise EtcdRendezvousRetryImmediately(
592                    "Rendezvous no longer frozen, before we confirmed. "
593                    "Must join next one"
594                )
595            if state["version"] != expected_version:
596                raise EtcdRendezvousRetryImmediately(
597                    "Rendezvous version changed. Must try join the new one."
598                )
599
600            this_lease_key = self.get_path(
601                f"/rdzv/v_{expected_version}/rank_{this_rank}"
602            )
603            self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL)
604
605            state["keep_alives"].append(this_lease_key)
606            if len(state["keep_alives"]) == len(state["participants"]):
607                # Everyone confirmed (this rank is last to do so)
608                state["status"] = "final"
609                state["num_workers_waiting"] = 0
610                finalize = True
611            else:
612                finalize = False
613
614            try:
615                # Compare-and-swap. If new state is still frozen, keep it ephemeral.
616                active_version = self.client.test_and_set(
617                    key=self.get_path("/rdzv/active_version"),
618                    value=json.dumps(state),
619                    prev_value=active_version.value,
620                    ttl=None if finalize else CONST_ETCD_FROZEN_TTL,
621                )
622
623                self._lease_this_rank_stop = self.setup_lease_renewal(
624                    this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL
625                )
626                return active_version
627
628            except etcd.EtcdCompareFailed:
629                logger.info("Confirm membership CAS unsuccessful, retrying")
630
631    def wait_for_final(self, expected_version):
632        """Helper method for the confirm phase."""
633        active_version, state = self.get_rdzv_state()
634        while True:
635            if state["status"] == "final" and state["version"] == expected_version:
636                # Success. This rendezvous is final, and we accept it.
637                return active_version
638
639            elif state["status"] == "frozen" and state["version"] == expected_version:
640                # Continue waiting for any interesting events.
641                active_version, state = self.try_wait_for_state_change(
642                    etcd_index=active_version.etcd_index + 1
643                )
644
645            else:
646                # No valid transition possible at this point
647                raise EtcdRendezvousRetryableFailure(
648                    "Rendezvous state transition no longer possible. Must re-enter."
649                )
650
651    def announce_self_waiting(self, expected_version):
652        """
653        Announce this worker is waiting (via num_workers_waiting counter) to join next
654        rendezvous, but only if state and version match.
655        """
656        while True:
657            cas_delay()
658            active_version, state = self.get_rdzv_state()
659
660            if state["status"] != "final" or state["version"] != expected_version:
661                raise EtcdRendezvousRetryImmediately
662
663            # Increment counter to signal an additional waiting worker.
664            state["num_workers_waiting"] += 1
665
666            try:
667                active_version = self.client.test_and_set(
668                    key=self.get_path("/rdzv/active_version"),
669                    value=json.dumps(state),
670                    prev_value=active_version.value,
671                )
672                return active_version
673
674            except etcd.EtcdCompareFailed:
675                logger.info("Announce self as waiting CAS unsuccessful, retrying")
676
677    def wait_for_rendezvous_to_free(self, expected_version):
678        """
679        When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join.
680
681        Such opportunity may come from:
682
683        1. rendezvous state changed by someone else, in which case we unblock and retry.
684        2. rendezvous becomes invalid because at least one member failed to renew their
685           leased keep_alive node. We detect this, and destroy the rendezvous.
686        """
687        active_version, state = self.get_rdzv_state()
688        while True:
689            if state["status"] != "final" or state["version"] != expected_version:
690                return
691
692            # Check if current rendezvous state is valid, in the sense that all
693            # its members are alive (renewing their lease).
694            # If not, try destroy this rendezvous, so a new one can be created.
695            alive_members = self.client.get(
696                self.get_path(f"/rdzv/v_{expected_version}")
697            )
698            keep_alive_keys = [ch.key for ch in alive_members.children]
699
700            for key in state["keep_alives"]:
701                if key not in keep_alive_keys:
702                    # This participant didn't renew their lease. We'll declare this
703                    # rendezvous version as dead (but only if it hadn't changed)
704                    logger.info("Keep-alive key %s is not renewed.", key)
705                    logger.info(
706                        "Rendezvous version %s is incomplete. ", expected_version
707                    )
708                    logger.info("Attempting to destroy it.")
709
710                    # Compare-and-delete operation. Throws if compare failed,
711                    # which means rendezvous was already destroyed/re-created/closed,
712                    # and we can try to re-enter the barrier.
713                    self.client.delete(
714                        key=self.get_path("/rdzv/active_version"),
715                        prevValue=active_version.value,
716                    )
717
718                    logger.info(
719                        "Destroyed rendezvous version %s successfully.",
720                        expected_version,
721                    )
722
723                    # We can return (and retry) immediately
724                    return
725
726            # Existing rendezvous seems valid, no reason to destroy it.
727            # We just have to wait until something changes and re-check.
728            try:
729                overall_timeout = (
730                    max(self._rendezvous_deadline - time.time(), 0.0) + 1.0
731                )
732                self.client.watch(
733                    key=self.get_path("/rdzv"),
734                    index=active_version.etcd_index + 1,
735                    recursive=True,
736                    timeout=overall_timeout,
737                )
738            except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
739                pass
740
741            if time.time() > self._rendezvous_deadline:
742                raise RendezvousTimeoutError
743            active_version, state = self.get_rdzv_state()
744
    def handle_join_last_call(self, expected_version, deadline):
        """
        After we reach min number of workers, one particular worker takes on the
        responsibility of waiting an additional timeout before closing the join window.
        If the worker responsible for this fails, the rendezvous will be destroyed due
        to expiring TTL, and the other participants will re-rendezvous.

        Here we expect to see state <joinable, expected_version>
        Exit gracefully if either:

        1. state becomes <frozen, expected_version>
        2. timeout happens (reaching deadline), in which case
           we try the transition to <frozen, expected_version>

        Exit with exception otherwise.

        Args:
            expected_version: rendezvous version this worker joined; any other
                version observed in etcd means the rendezvous moved on.
            deadline: absolute time (as returned by ``time.time()``) at which
                the last-call window closes and the join window is frozen.

        Raises:
            EtcdRendezvousRetryableFailure: if the state is no longer
                <joinable, expected_version> (and not already frozen at it).
            RendezvousTimeoutError: propagated from ``try_wait_for_state_change``
                when the overall rendezvous deadline has passed.
        """
        active_version, state = self.get_rdzv_state()
        while True:
            if state["status"] == "frozen" and state["version"] == expected_version:
                # Worker set became frozen before last-call timeout. This is possible
                # when num_max_workers is reached before the timeout.
                return

            if state["status"] != "joinable" or state["version"] != expected_version:
                raise EtcdRendezvousRetryableFailure(
                    "Rendezvous state transition no longer possible. Must re-enter."
                )

            # If timeout occurred, attempt a state transition (joinable -> frozen)
            if time.time() >= deadline:
                # Mutating the local copy is safe: on CAS failure it is re-fetched below.
                state["status"] = "frozen"
                state["keep_alives"] = []
                try:
                    active_version = self.client.test_and_set(
                        key=self.get_path("/rdzv/active_version"),
                        value=json.dumps(state),
                        prev_value=active_version.value,
                        ttl=CONST_ETCD_FROZEN_TTL,
                    )
                    # We successfully made this rendezvous frozen.
                    return
                except etcd.EtcdCompareFailed:
                    # Someone else changed the state first; re-read and re-evaluate
                    # from the top of the loop (it may already be frozen).
                    logger.info(
                        "Join last-call transition CAS unsuccessful. Will retry"
                    )
                    cas_delay()
                    active_version, state = self.get_rdzv_state()
                    continue

            # Timeout did not occur, so we must refresh TTL, and wait for
            # further changes. Note: we only want TTL to be refreshed if
            # state is still joinable, hence we use CAS for that here,
            # even though we don't change any of the data.
            try:
                active_version = self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=active_version.value,
                    prev_value=active_version.value,
                    ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL,
                )

                # Minimize "oversleeping":
                # NOTE(review): this can be <= 0 if more than ~1s elapses between
                # the deadline check above and this line (e.g. a slow CAS).
                timeout = min(
                    CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2,
                    deadline - time.time() + 1.0,  # Oversleeping by 1s is ok.
                )
                active_version, state = self.try_wait_for_state_change(
                    etcd_index=active_version.etcd_index + 1, timeout=timeout
                )
            except etcd.EtcdCompareFailed:
                # State changed under us (no longer byte-identical); re-read and loop.
                logger.info("Join last-call TTL refresh CAS unsuccessful, will retry")
                cas_delay()
                active_version, state = self.get_rdzv_state()
818
819    def set_closed(self):
820        """
821        Mark rendezvous 'closed' for current run_id, which is used to signal other
822        participants to not attempt to perform (re-)rendezvous. This is useful
823        when one of the workers decides the job is complete.
824        """
825        while True:
826            active_version, state = self.get_rdzv_state()
827
828            if state["status"] == "closed":
829                # Already closed by someone else.
830                return
831
832            state["status"] = "closed"
833            try:
834                self.client.test_and_set(
835                    key=self.get_path("/rdzv/active_version"),
836                    value=json.dumps(state),
837                    prev_value=active_version.value,
838                )
839                return
840
841            except etcd.EtcdCompareFailed:
842                logger.info("Set closed CAS unsuccessful, retrying")
843                cas_delay()
844
845    def get_rdzv_state(self):
846        active_version = self.client.get(key=self.get_path("/rdzv/active_version"))
847        return active_version, json.loads(active_version.value)
848
849    def try_wait_for_state_change(self, etcd_index, timeout=None):
850        # Don't sleep past the overall deadline (at least more than by 1s)
851        overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0
852        timeout = overall_timeout if timeout is None else min(timeout, overall_timeout)
853
854        try:
855            self.client.watch(
856                self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout
857            )
858        except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
859            pass
860
861        if time.time() > self._rendezvous_deadline:
862            raise RendezvousTimeoutError
863
864        # Unfortunately, we have to do another fetch in order to get last etcd_index.
865        return self.get_rdzv_state()
866
867    def get_path(self, path):
868        if not path.startswith("/"):
869            path = "/" + path
870
871        return f"{self._prefix}run_{self._run_id}{path}"
872
873    def create_path_if_not_exists(self, full_path, ttl=None):
874        try:
875            self.client.write(
876                key=full_path, value=None, dir=True, prevExist=False, ttl=ttl
877            )
878        except etcd.EtcdAlreadyExist:
879            pass
880
881    def setup_lease_renewal(self, full_path, ttl):
882        # NOTE: For ephemeral key TTL renewal (~lease) to work correctly,
883        # make sure you don't call any long-blocking methods that do not
884        # release the Python's GIL! An example of this is calling a pybind11
885        # extension function that is blocking / long-running, but is not
886        # doing a scoped release of the GIL.
887        def lease_worker(client, path, ttl, stop_event):
888            while True:
889                try:
890                    client.refresh(path, ttl=ttl)
891                except etcd.EtcdKeyNotFound:
892                    break
893                except ConnectionRefusedError:
894                    # This error usually occurs during test when the server already got terminated but the
895                    # python garbage collector have not yet invoked the __del__ method.
896                    break
897
898                if stop_event.wait(timeout=ttl / 2):
899                    break
900
901        lease_stop_event = threading.Event()
902        lease_thread = threading.Thread(
903            target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event)
904        )
905
906        lease_thread.daemon = True
907        lease_thread.start()
908
909        return lease_stop_event
910
911    def store_extra_data(self, rdzv_version, key, value):
912        node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
913        try:
914            # If first time we are storing anything:
915            extra_data = self.client.write(
916                key=node, value=json.dumps({key: value}), prevExist=False
917            )
918            return
919        except etcd.EtcdAlreadyExist:
920            pass
921
922        # CAS loop, to make sure we don't lose concurrent stores.
923        while True:
924            # We never delete extra_data. Failure here should be fatal, no special handling.
925            extra_data = self.client.get(node)
926
927            new_extra_data_value = json.loads(extra_data.value)
928            new_extra_data_value[key] = value
929
930            try:
931                extra_data = self.client.test_and_set(
932                    key=node,
933                    value=json.dumps(new_extra_data_value),
934                    prev_value=extra_data.value,
935                )
936                return
937            except etcd.EtcdCompareFailed:
938                logger.info("Store extra_data CAS unsuccessful, retrying")
939                time.sleep(0.1)
940
941    def load_extra_data(self, rdzv_version, key, timeout=None):
942        # 'extra_data' node itself, and the directory it is located in:
943        node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
944        node_dir = self.get_path(f"/rdzv/v_{rdzv_version}")
945
946        # TODO: implement timeout
947        # https://github.com/pytorch/elastic/issues/12
948        while True:
949            # Combined wait for the node itself, and the key inside it.
950            root = self.client.get(node_dir)
951
952            # Find the extra_data node, if it exists
953            extra_data = [n for n in root.children if n.key == node]
954            assert len(extra_data) <= 1
955
956            # Node for extra_data exists, check the desired key inside it.
957            if len(extra_data) == 1:
958                extra_data_dict = json.loads(extra_data[0].value)
959                if key in extra_data_dict:
960                    return extra_data_dict[key]
961
962            # The 'extra_data' node doesn't exist, or they key isn't published yet.
963            # Wait for interesting events on the extra_data node and retry.
964            try:
965                self.client.watch(node, index=root.etcd_index + 1)
966            except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
967                pass
968
969    def setup_kv_store(self, rdzv_version):
970        store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv")
971        self.create_path_if_not_exists(store_path)
972        return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path)
973
974
def _create_etcd_client(params: RendezvousParameters) -> etcd.Client:
    """
    Create a new ``etcd.Client`` from the specified ``RendezvousParameters``.

    Reads ``protocol`` (``"http"`` default, or ``"https"``), ``cert``, ``key``
    and ``cacert`` from ``params.config``. The endpoint port defaults to 2379.

    Raises:
        ValueError: if ``protocol`` is set to anything other than http/https.
    """
    hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379)

    protocol = params.config.get("protocol")
    if protocol is None:
        protocol = "http"
    elif protocol not in ("http", "https"):
        raise ValueError("The etcd protocol must be HTTP or HTTPS.")

    # Client certificate; when a private key is also configured the etcd
    # client expects them together as a (cert, key) tuple.
    client_cert = params.config.get("cert")
    if client_cert is not None:
        client_key = params.config.get("key")
        client_cert = client_cert if client_key is None else (client_cert, client_key)

    root_ca = params.config.get("cacert")

    return etcd.Client(
        hostname,
        port,
        protocol=protocol,
        cert=client_cert,
        ca_cert=root_ca,
        allow_reconnect=True,
    )
1007
1008
1009# Handler for torch.distributed "static" registration
1010def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
1011    """
1012    Usage:
1013
1014    ::
1015
1016    rdzv_params = RendezvousParameters(
1017                        backend="etcd",
1018                        endpoint="192.168.0.42:2379",
1019                        run_id="123",
1020                        min_nodes=4,
1021                        max_nodes=8,
1022                        timeout=300,
1023                        last_call_timeout=30,
1024                        etcd_prefix="custom_prefix",
1025                        protocol="https",
1026                        cacert="/etc/kubernetes/certs/ca.crt",
1027                        cert="/etc/kubernetes/certs/client.crt",
1028                        key="/etc/kubernetes/certs/client.key")
1029    # -- or --
1030    rdzv_params = RendezvousParameters(
1031                        backend="etcd",
1032                        endpoint="192.168.0.42:2379",
1033                        run_id="123",
1034                        min_nodes=4,
1035                        max_nodes=8)
1036
1037    etcd_rdzv_handler = create_etcd_rendezvous_handler(rdzv_params)
1038
1039
1040    Where:
1041        run_id - unique id for this training job instance,
1042        min_nodes - min number of workers expected to join the rendezvous,
1043        max_nodes - max number of workers allowed to join the rendezvous,
1044                        defaults to min_workers is not specified.
1045        timeout - total timeout within which next_rendezvous is expected to
1046                      succeed; a RendezvousTimeoutError is raised otherwise;
1047                      Defaults is 600 (10 minutes).
1048        last_call_timeout - additional wait amount ("last call") after
1049                            min number of workers has been reached.
1050                            Defaults to 30 seconds.
1051        etcd_prefix - path prefix (from etcd root), inside which all
1052                      etcd nodes will be created.
1053                      Default is "/torchelastic/p2p".
1054        protocol - http (default) or https to access etcd.
1055        cacert - CA cert to access etcd, only makes sense with https.
1056        cert - client cert to access etcd, only makes sense with https.
1057        key - client key to access etcd, only makes sense with https.
1058    """
1059    client = _create_etcd_client(params)
1060
1061    etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p")
1062
1063    rdzv = EtcdRendezvous(
1064        client=client,
1065        prefix=etcd_prefix,
1066        run_id=params.run_id,
1067        num_min_workers=params.min_nodes,
1068        num_max_workers=params.max_nodes,
1069        timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT),
1070        last_call_timeout=params.get_as_int(
1071            "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT
1072        ),
1073    )
1074    return EtcdRendezvousHandler(
1075        rdzv_impl=rdzv,
1076        local_addr=params.local_addr,
1077    )
1078