1#!/usr/bin/env python3 2# mypy: allow-untyped-defs 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9 10import json 11import logging 12import sys 13import threading 14import time 15from typing import Optional 16 17import etcd # type: ignore[import] 18 19from torch.distributed.elastic.rendezvous import ( 20 RendezvousClosedError, 21 RendezvousError, 22 RendezvousHandler, 23 RendezvousInfo, 24 RendezvousParameters, 25 RendezvousStoreInfo, 26 RendezvousTimeoutError, 27) 28 29from .etcd_store import cas_delay, EtcdStore 30from .utils import parse_rendezvous_endpoint 31 32 33__all__ = [ 34 "EtcdRendezvousRetryableFailure", 35 "EtcdRendezvousRetryImmediately", 36 "EtcdRendezvousHandler", 37 "EtcdRendezvous", 38 "create_rdzv_handler", 39] 40 41_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s") 42_log_handler = logging.StreamHandler(sys.stderr) 43_log_handler.setFormatter(_log_fmt) 44 45logger = logging.getLogger(__name__) 46logger.propagate = False 47logger.setLevel(logging.INFO) 48logger.addHandler(_log_handler) 49 50 51# Retryable failure exception means the we were too late to make 52# a desired state transition (e.g. because of a race condition), 53# and should now restart from the beginning. 54# A small delay is recommended to avoid spamming Etcd. 55class EtcdRendezvousRetryableFailure(Exception): 56 pass 57 58 59# Similar to retryable failure, but the new state we observed suggests we 60# can re-try immediately, i.e. without a need for "safety delay". 61class EtcdRendezvousRetryImmediately(Exception): 62 pass 63 64 65# Default timeout for the rendezvous. 66_DEFAULT_TIMEOUT: int = 600 # 10 minutes 67 68# Additional waiting time after reaching the minimum number of nodes 69# in case the rendezvous is elastic (min != max). 70_DEFAULT_LAST_CALL_TIMEOUT: int = 30 # 30 seconds 71 72# Various constants used internally in EtcdRendezvous 73CONST_ETCD_SETUP_TTL = 5 74CONST_ETCD_FROZEN_TTL = 10 75CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10 76 77# Ephemeral node TTL for worker's keep-alive key: 78CONST_WORKER_KEEPALIVE_TTL = 10 79 80# TTL for the ephemeral run_id-specific directory. All rendezvous state data 81# for a specific run_id (job instance) is contained within directory. 82# Its only role is to clean-up rendezvous data from old runs (for the case when 83# etcd server is persistent), and has no affect on correctness, but should be 84# larger than any timeouts that a worker process is expected to survive: 85CONST_RUNID_SUBROOT_TTL = 7200 # 2 hours 86 87 88class EtcdRendezvousHandler(RendezvousHandler): 89 """ 90 Implements a 91 :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` interface 92 backed by 93 :py:class:`torch.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`. 94 ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to 95 use and to pass implementation specific configurations to the rendezvous 96 module. The basic etcd rendezvous configuration URL looks like the following 97 :: 98 99 etcd://<etcd_address>:<port>/<job_id>?min_workers=<min_workers>&max_workers=<max_workers> # noqa: W605 100 101 -- example -- 102 103 etcd://localhost:2379/1234?min_workers=1&max_workers=3 104 105 The URL above is interpreted as follows: 106 107 1. Use the rendezvous handler that is registered with the ``etcd`` 108 scheme 109 2. The ``etcd`` endpoint to use is ``localhost:2379`` 110 3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to 111 share a common etcd server for multiple jobs so long as the 112 ``job_ids`` are guaranteed to be unique). Note that the job id can be 113 any string (e.g. does not need to be a number) as long as it is 114 unique. 115 4. ``min_workers=1`` and ``max_workers=3`` specifies a range for 116 membership size - Torch Distributed Elastic starts running the job as 117 long as the cluster size is greater than or equal to ``min_workers`` 118 and admits up to ``max_workers`` into the cluster. 119 120 Below are a full list of the parameters that can be passed to etcd 121 rendezvous: 122 123 +--------------------------------------------+--------------------------+ 124 | Parameter | Description | 125 +============================================+==========================+ 126 | min_workers | minimum number of | 127 | | workers for the | 128 | | rendezvous to be valid | 129 +--------------------------------------------+--------------------------+ 130 | max_workers | maximum number of | 131 | | workers to admit | 132 +--------------------------------------------+--------------------------+ 133 | timeout | total timeout within | 134 | | which next_rendezvous is | 135 | | expected to succeed | 136 | | (default 600s) | 137 +--------------------------------------------+--------------------------+ 138 | last_call_timeout | additional wait amount | 139 | | ("last call") after min | 140 | | number of workers has | 141 | | been reached (defaults | 142 | | to 30s) | 143 +--------------------------------------------+--------------------------+ 144 | etcd_prefix | path prefix (from etcd | 145 | | root), inside which all | 146 | | etcd nodes will be | 147 | | created (defaults to | 148 | | ``/torchelastic/p2p``) | 149 +--------------------------------------------+--------------------------+ 150 """ 151 152 def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]): 153 """ 154 Args: 155 rdzv_impl: the implementation of the rendezvous 156 local_addr: the local address of the current node 157 """ 158 159 self._rdzv_impl = rdzv_impl 160 self._local_addr = local_addr 161 162 def __del__(self): 163 # TODO: look into using weakref here instead. 164 del self._rdzv_impl 165 166 def get_backend(self) -> str: 167 return "etcd" 168 169 def next_rendezvous(self): 170 rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier() 171 172 logger.info("Creating EtcdStore as the c10d::Store implementation") 173 store = self._rdzv_impl.setup_kv_store(rdzv_version) 174 175 bootstrap_store_info = RendezvousStoreInfo.build( 176 rank, store, local_addr=self._local_addr 177 ) 178 return RendezvousInfo(store, rank, world_size, bootstrap_store_info) 179 180 def is_closed(self): 181 try: 182 _, state = self._rdzv_impl.get_rdzv_state() 183 return state["status"] == "closed" 184 except etcd.EtcdKeyNotFound: 185 # No rendezvous state, so it cannot be closed. 186 return False 187 188 def set_closed(self): 189 self._rdzv_impl.set_closed() 190 191 def num_nodes_waiting(self): 192 try: 193 _, state = self._rdzv_impl.get_rdzv_state() 194 if state["status"] == "final": 195 return state["num_workers_waiting"] 196 except etcd.EtcdKeyNotFound: 197 pass 198 return 0 199 200 def get_run_id(self) -> str: 201 return self._rdzv_impl._run_id 202 203 def shutdown(self) -> bool: 204 try: 205 self.set_closed() 206 return True 207 except BaseException as e: 208 logger.warning("Shutdown failed. Error occurred: %s", str(e)) 209 return False 210 211 212# TODO: we should probably handle a few additional errors, 213# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are 214# only relevant for multi-node Etcd ensemble. A simple retry would work, 215# but is verbose to add everywhere. Consider wrapping the client calls 216# into auto-retry for these errors? 217# 218class EtcdRendezvous: 219 """A rendezvous implementation that uses `etcd <https://etcd.io/>`__ as the backend store.""" 220 221 def __init__( 222 self, 223 client, 224 prefix, 225 run_id, 226 num_min_workers, 227 num_max_workers, 228 timeout, 229 last_call_timeout, 230 ): 231 self.client = client 232 logger.info("Etcd machines: %s", self.client.machines) 233 234 self._prefix = prefix 235 self._run_id = run_id 236 self._num_min_workers = num_min_workers 237 self._num_max_workers = num_max_workers 238 self._timeout = timeout 239 self._last_call_timeout = last_call_timeout 240 241 # For cleaning up TTL refresher threads (for ephemeral keys) 242 self._lease_run_id_stop = None 243 self._lease_this_rank_stop = None 244 245 if not self._prefix.endswith("/"): 246 self._prefix += "/" 247 248 # Setup a permanent prefix dir, if didn't exist 249 if self._prefix != "/": 250 self.create_path_if_not_exists(self._prefix) 251 252 # Lease a "sub-root" node specific to this job instance (run_id) 253 self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL) 254 self._lease_run_id_stop = self.setup_lease_renewal( 255 self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL 256 ) 257 258 # Subdir for all rendezvous work 259 self.create_path_if_not_exists(self.get_path("/rdzv")) 260 261 # Create a rendezvous version counter, if doesn't exist 262 try: 263 self.client.write( 264 key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False 265 ) 266 except etcd.EtcdAlreadyExist: 267 pass 268 269 def __del__(self): 270 # TODO: look into using weakref here instead. 271 if self._lease_run_id_stop is not None: 272 self._lease_run_id_stop.set() 273 274 if self._lease_this_rank_stop is not None: 275 self._lease_this_rank_stop.set() 276 277 def rendezvous_barrier(self): 278 """ 279 Main entry point for next rendezvous. 280 281 This method is blocking until rendezvous succeeds or a timeout occurs. 282 283 Returns: 284 ``(rdzv_version, rank, world_size)`` 285 286 Raises: 287 RendezvousTimeoutError - timeout waiting for rendezvous 288 RendezvousClosedError - rendezvous is or was closed while waiting 289 RendezvousError - other persistent errors that 290 render the rendezvous non-retryable 291 """ 292 self._rendezvous_deadline = time.time() + self._timeout 293 while True: 294 if time.time() > self._rendezvous_deadline: 295 raise RendezvousTimeoutError 296 297 logger.info("Attempting to join next rendezvous") 298 try: 299 # Dis-own our lease in the previous rendezvous, if exists 300 if self._lease_this_rank_stop is not None: 301 self._lease_this_rank_stop.set() 302 303 return self.init_phase() 304 305 except EtcdRendezvousRetryImmediately: 306 # The type of failure suggests we can retry without delay 307 pass 308 309 except EtcdRendezvousRetryableFailure: 310 # In case of retryable failure, wait a small delay 311 # to avoid spamming etcd 312 time.sleep(1) 313 314 except RendezvousTimeoutError: 315 logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler") 316 raise 317 318 except RendezvousClosedError: 319 logger.info( 320 "Rendezvous for run_id=%s was observed to be closed", self._run_id 321 ) 322 raise 323 324 except RendezvousError: 325 raise 326 327 except Exception as e: 328 # In case of a general exception, wait a small delay 329 # to avoid spamming etcd 330 # FIXME: there are a few things that fall under this like 331 # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly. 332 logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) 333 time.sleep(1) 334 335 def init_phase(self): 336 """ 337 Initially, the rendezvous state is expected to be one of: 338 339 1. empty (non-existent) - in this case we try to create a new one. 340 2. joinable - we try to join it. 341 3. final - we announce ourselves as waiting, and go into monitoring mode 342 343 Any other state is considered transitional, and will be retried after 344 a short delay. 345 346 Returns: 347 ``(rdzv_version, rank, world_size)`` 348 349 Raises: 350 RendezvousClosedError - current rendezvous was/is closed 351 EtcdRendezvousRetryableFailure - observed some intermediate 352 state, which is best handled by retrying later 353 """ 354 try: 355 active_version = self.try_create_rendezvous() 356 state = json.loads(active_version.value) 357 logger.info("New rendezvous state created: %s", state) 358 except etcd.EtcdAlreadyExist: 359 active_version, state = self.get_rdzv_state() 360 # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound), 361 # but this is ok for us - just means we'll restart from beginning. 362 logger.info("Observed existing rendezvous state: %s", state) 363 364 if state["status"] == "closed": 365 raise RendezvousClosedError 366 367 if state["status"] == "joinable": 368 return self.join_phase(state["version"]) 369 370 if state["status"] == "final": 371 self.handle_existing_rendezvous(state["version"]) 372 raise EtcdRendezvousRetryImmediately 373 374 self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1) 375 raise EtcdRendezvousRetryableFailure 376 377 def join_phase(self, expected_version): 378 """ 379 We observed a rendezvous state in 'joinable' state, and attempt to join this 380 particular version, and then wait for all other peers to join. 381 """ 382 # Failure to join will propagate an exception, causing a re-entry. 383 active_version, this_rank = self.join_rendezvous(expected_version) 384 state = json.loads(active_version.value) 385 logger.info( 386 "Joined rendezvous version %s as rank %s. Full state: %s", 387 state["version"], 388 this_rank, 389 state, 390 ) 391 392 # If this worker was first to reach num_min_workers requirement, 393 # and rendezvous is still joinable (therefore it is elastic), 394 # then this worker will be responsible for waiting out the "last call" 395 # timeout and closing (i.e. transitioning to 'frozen') the rendezvous 396 # afterwards. 397 # As a safety against a potential failure of this worker (during the 398 # last call timeout), the rendezvous state is made ephemeral 399 # when min_num_workers is reached. 400 401 if this_rank == self._num_min_workers - 1 and state["status"] == "joinable": 402 logger.info("Rank %s is responsible for join last call.", this_rank) 403 last_call_deadline = time.time() + self._last_call_timeout 404 self.handle_join_last_call(expected_version, last_call_deadline) 405 logger.info("Rank %s finished join last call.", this_rank) 406 407 # Wait for rendezvous state to be frozen, which means a fixed set of peers 408 logger.info("Waiting for remaining peers.") 409 active_version = self.wait_for_peers(expected_version) 410 state = json.loads(active_version.value) 411 412 assert ( 413 state["version"] == expected_version 414 ), "Logic error: failed to observe version mismatch" 415 416 return self.confirm_phase(expected_version, this_rank) 417 418 def confirm_phase(self, expected_version, this_rank): 419 """ 420 Once the rendezvous state transitions from 'joinable' to 'frozen', 421 we have every participant confirm their membership and setup per-member 422 keep-alive TTL keys, and then wait for all other participants to confirm, 423 which would then successfully conclude this rendezvous. 424 """ 425 logger.info("All peers arrived. Confirming membership.") 426 self.confirm_membership(expected_version, this_rank) 427 428 logger.info("Waiting for confirmations from all peers.") 429 active_version = self.wait_for_final(expected_version) 430 state = json.loads(active_version.value) 431 432 logger.info( 433 "Rendezvous version %s is complete. Final state: %s", 434 state["version"], 435 state, 436 ) 437 438 # Rendezvous version number; our rank in it; world size 439 return state["version"], this_rank, len(state["participants"]) 440 441 def handle_existing_rendezvous(self, expected_version): 442 """ 443 Handle the case when there's an existing (state 'final) rendezvous already 444 in place, and we have to announce ourselves waiting, and wait until 445 the next rendezvous opportunity. 446 """ 447 # If state is 'final' -> increment num_workers_waiting 448 # Then, observe state changes: 449 # 1. if it's no longer final -> bail out and re-try 450 # 2. if keep alives are missing, destroy it and bail out. 451 active_state = self.announce_self_waiting(expected_version) 452 logger.info( 453 "Added self to waiting list. Rendezvous full state: %s", active_state.value 454 ) 455 456 self.wait_for_rendezvous_to_free(expected_version) 457 logger.info( 458 "Previously existing rendezvous state changed. Will re-try joining." 459 ) 460 461 def try_create_rendezvous(self): 462 """ 463 Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists). 464 465 Raises: 466 RendezvousError - on unexpected state 467 """ 468 # Initially active_version is ephemeral - this is to handle the 469 # possibility that might fail to complete the setup transaction, 470 # i.e. the transition "setup" -> "joinable". 471 active_version = self.client.write( 472 key=self.get_path("/rdzv/active_version"), 473 value=json.dumps({"status": "setup"}), 474 prevExist=False, 475 ttl=CONST_ETCD_SETUP_TTL, 476 ) 477 478 try: 479 version_counter = self.client.get(self.get_path("/rdzv/version_counter")) 480 version_counter.value = str(int(version_counter.value) + 1) 481 self.client.update(version_counter) 482 except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e: 483 raise RendezvousError( 484 "Unexpected state of EtcdRendezvousHandler, worker needs to die." 485 ) from e 486 487 # Any failure below results in declaring a retryable rendezvous failure. 488 # The ephemeral /rdzv/active_version will expire and someone can then 489 # re-try the setup process. 490 491 # Create directory node for participant data 492 self.client.write( 493 key=self.get_path(f"/rdzv/v_{version_counter.value}"), 494 value=None, 495 dir=True, 496 prevExist=False, 497 ) 498 499 # Publish rendezvous version and signal it is ready-to-be-joined. 500 # If rendezvous was set closed just before this, a retry will happen, 501 # where the closed condition will be handled. 502 return self.client.test_and_set( 503 key=self.get_path("/rdzv/active_version"), 504 value=json.dumps( 505 { 506 "status": "joinable", 507 "version": version_counter.value, 508 "participants": [], 509 } 510 ), 511 prev_value=active_version.value, 512 ) 513 514 def join_rendezvous(self, expected_version): 515 """Helper method for the join phase.""" 516 # Use compare-and-swap to add self to rendezvous state: 517 while True: 518 cas_delay() 519 active_version, state = self.get_rdzv_state() 520 521 if state["status"] != "joinable": 522 raise EtcdRendezvousRetryableFailure( 523 "Rendezvous state became non-joinable before we could join. " 524 "Must join next one." 525 ) 526 527 if state["version"] != expected_version: 528 raise EtcdRendezvousRetryImmediately( 529 "Rendezvous version changed. Must try join the new one." 530 ) 531 532 assert ( 533 len(state["participants"]) < self._num_max_workers 534 ), "Logic error: joinable rendezvous should always have space left" 535 536 this_rank = len(state["participants"]) 537 state["participants"].append(this_rank) 538 539 # When reaching min workers, or changing state to frozen, we'll set 540 # the active_version node to be ephemeral. 541 set_ttl: Optional[int] = None 542 if len(state["participants"]) == self._num_max_workers: 543 state["status"] = "frozen" 544 state["keep_alives"] = [] 545 set_ttl = CONST_ETCD_FROZEN_TTL 546 elif len(state["participants"]) >= self._num_min_workers: 547 set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL 548 549 try: 550 # Compare-and-swap. 551 active_version = self.client.test_and_set( 552 key=self.get_path("/rdzv/active_version"), 553 value=json.dumps(state), 554 prev_value=active_version.value, 555 ttl=set_ttl, 556 ) 557 # We succeeded joining. 558 return active_version, this_rank 559 560 except etcd.EtcdCompareFailed: 561 logger.info("Join rendezvous CAS unsuccessful, retrying") 562 563 def wait_for_peers(self, expected_version): 564 """Helper method for the join phase.""" 565 active_version, state = self.get_rdzv_state() 566 while True: 567 if state["status"] == "frozen" and state["version"] == expected_version: 568 # Success, all peers arrived. 569 return active_version 570 571 elif state["status"] == "joinable" and state["version"] == expected_version: 572 # Continue waiting for any interesting events. 573 active_version, state = self.try_wait_for_state_change( 574 etcd_index=active_version.etcd_index + 1 575 ) 576 577 else: 578 # No valid transition possible at this point 579 raise EtcdRendezvousRetryableFailure( 580 "Rendezvous state transition no longer possible. Must re-enter." 581 ) 582 583 def confirm_membership(self, expected_version, this_rank): 584 """Helper method for the confirm phase.""" 585 # Compare-and-swap loop 586 while True: 587 cas_delay() 588 active_version, state = self.get_rdzv_state() 589 590 if state["status"] != "frozen": 591 raise EtcdRendezvousRetryImmediately( 592 "Rendezvous no longer frozen, before we confirmed. " 593 "Must join next one" 594 ) 595 if state["version"] != expected_version: 596 raise EtcdRendezvousRetryImmediately( 597 "Rendezvous version changed. Must try join the new one." 598 ) 599 600 this_lease_key = self.get_path( 601 f"/rdzv/v_{expected_version}/rank_{this_rank}" 602 ) 603 self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL) 604 605 state["keep_alives"].append(this_lease_key) 606 if len(state["keep_alives"]) == len(state["participants"]): 607 # Everyone confirmed (this rank is last to do so) 608 state["status"] = "final" 609 state["num_workers_waiting"] = 0 610 finalize = True 611 else: 612 finalize = False 613 614 try: 615 # Compare-and-swap. If new state is still frozen, keep it ephemeral. 616 active_version = self.client.test_and_set( 617 key=self.get_path("/rdzv/active_version"), 618 value=json.dumps(state), 619 prev_value=active_version.value, 620 ttl=None if finalize else CONST_ETCD_FROZEN_TTL, 621 ) 622 623 self._lease_this_rank_stop = self.setup_lease_renewal( 624 this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL 625 ) 626 return active_version 627 628 except etcd.EtcdCompareFailed: 629 logger.info("Confirm membership CAS unsuccessful, retrying") 630 631 def wait_for_final(self, expected_version): 632 """Helper method for the confirm phase.""" 633 active_version, state = self.get_rdzv_state() 634 while True: 635 if state["status"] == "final" and state["version"] == expected_version: 636 # Success. This rendezvous is final, and we accept it. 637 return active_version 638 639 elif state["status"] == "frozen" and state["version"] == expected_version: 640 # Continue waiting for any interesting events. 641 active_version, state = self.try_wait_for_state_change( 642 etcd_index=active_version.etcd_index + 1 643 ) 644 645 else: 646 # No valid transition possible at this point 647 raise EtcdRendezvousRetryableFailure( 648 "Rendezvous state transition no longer possible. Must re-enter." 649 ) 650 651 def announce_self_waiting(self, expected_version): 652 """ 653 Announce this worker is waiting (via num_workers_waiting counter) to join next 654 rendezvous, but only if state and version match. 655 """ 656 while True: 657 cas_delay() 658 active_version, state = self.get_rdzv_state() 659 660 if state["status"] != "final" or state["version"] != expected_version: 661 raise EtcdRendezvousRetryImmediately 662 663 # Increment counter to signal an additional waiting worker. 664 state["num_workers_waiting"] += 1 665 666 try: 667 active_version = self.client.test_and_set( 668 key=self.get_path("/rdzv/active_version"), 669 value=json.dumps(state), 670 prev_value=active_version.value, 671 ) 672 return active_version 673 674 except etcd.EtcdCompareFailed: 675 logger.info("Announce self as waiting CAS unsuccessful, retrying") 676 677 def wait_for_rendezvous_to_free(self, expected_version): 678 """ 679 When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join. 680 681 Such opportunity may come from: 682 683 1. rendezvous state changed by someone else, in which case we unblock and retry. 684 2. rendezvous becomes invalid because at least one member failed to renew their 685 leased keep_alive node. We detect this, and destroy the rendezvous. 686 """ 687 active_version, state = self.get_rdzv_state() 688 while True: 689 if state["status"] != "final" or state["version"] != expected_version: 690 return 691 692 # Check if current rendezvous state is valid, in the sense that all 693 # its members are alive (renewing their lease). 694 # If not, try destroy this rendezvous, so a new one can be created. 695 alive_members = self.client.get( 696 self.get_path(f"/rdzv/v_{expected_version}") 697 ) 698 keep_alive_keys = [ch.key for ch in alive_members.children] 699 700 for key in state["keep_alives"]: 701 if key not in keep_alive_keys: 702 # This participant didn't renew their lease. We'll declare this 703 # rendezvous version as dead (but only if it hadn't changed) 704 logger.info("Keep-alive key %s is not renewed.", key) 705 logger.info( 706 "Rendezvous version %s is incomplete. ", expected_version 707 ) 708 logger.info("Attempting to destroy it.") 709 710 # Compare-and-delete operation. Throws if compare failed, 711 # which means rendezvous was already destroyed/re-created/closed, 712 # and we can try to re-enter the barrier. 713 self.client.delete( 714 key=self.get_path("/rdzv/active_version"), 715 prevValue=active_version.value, 716 ) 717 718 logger.info( 719 "Destroyed rendezvous version %s successfully.", 720 expected_version, 721 ) 722 723 # We can return (and retry) immediately 724 return 725 726 # Existing rendezvous seems valid, no reason to destroy it. 727 # We just have to wait until something changes and re-check. 728 try: 729 overall_timeout = ( 730 max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 731 ) 732 self.client.watch( 733 key=self.get_path("/rdzv"), 734 index=active_version.etcd_index + 1, 735 recursive=True, 736 timeout=overall_timeout, 737 ) 738 except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): 739 pass 740 741 if time.time() > self._rendezvous_deadline: 742 raise RendezvousTimeoutError 743 active_version, state = self.get_rdzv_state() 744 745 def handle_join_last_call(self, expected_version, deadline): 746 """ 747 After we reach min number of workers, one particular worker takes on the 748 responsibility of waiting an additional timeout before closing the join window. 749 If the worker responsible for this fails, the rendezvous will be destroyed due 750 to expiring TTL, and the other participants will re-rendezvous. 751 752 Here we expect to see state <joinable, expected_version> 753 Exit gracefully if either: 754 755 1. state becomes <frozen, expected_version> 756 2. timeout happens (reaching deadline), in which case 757 we try the transition to <frozen, expected_version> 758 759 Exit with exception otherwise. 760 """ 761 active_version, state = self.get_rdzv_state() 762 while True: 763 if state["status"] == "frozen" and state["version"] == expected_version: 764 # Worker set became frozen before last-call timeout. This is possible 765 # when num_max_workers is reached before the timeout. 766 return 767 768 if state["status"] != "joinable" or state["version"] != expected_version: 769 raise EtcdRendezvousRetryableFailure( 770 "Rendezvous state transition no longer possible. Must re-enter." 771 ) 772 773 # If timeout occurred, attempt a state transition (joinable -> frozen) 774 if time.time() >= deadline: 775 state["status"] = "frozen" 776 state["keep_alives"] = [] 777 try: 778 active_version = self.client.test_and_set( 779 key=self.get_path("/rdzv/active_version"), 780 value=json.dumps(state), 781 prev_value=active_version.value, 782 ttl=CONST_ETCD_FROZEN_TTL, 783 ) 784 # We successfully made this rendezvous frozen. 785 return 786 except etcd.EtcdCompareFailed: 787 logger.info( 788 "Join last-call transition CAS unsuccessful. Will retry" 789 ) 790 cas_delay() 791 active_version, state = self.get_rdzv_state() 792 continue 793 794 # Timeout did not occur, so we must refresh TTL, and wait for 795 # further changes. Note: we only want TTL to be refreshed if 796 # state is still joinable, hence we use CAS for that here, 797 # even though we don't change any of the data. 798 try: 799 active_version = self.client.test_and_set( 800 key=self.get_path("/rdzv/active_version"), 801 value=active_version.value, 802 prev_value=active_version.value, 803 ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL, 804 ) 805 806 # Minimize "oversleeping": 807 timeout = min( 808 CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2, 809 deadline - time.time() + 1.0, # Oversleeping by 1s is ok. 810 ) 811 active_version, state = self.try_wait_for_state_change( 812 etcd_index=active_version.etcd_index + 1, timeout=timeout 813 ) 814 except etcd.EtcdCompareFailed: 815 logger.info("Join last-call TTL refresh CAS unsuccessful, will retry") 816 cas_delay() 817 active_version, state = self.get_rdzv_state() 818 819 def set_closed(self): 820 """ 821 Mark rendezvous 'closed' for current run_id, which is used to signal other 822 participants to not attempt to perform (re-)rendezvous. This is useful 823 when one of the workers decides the job is complete. 824 """ 825 while True: 826 active_version, state = self.get_rdzv_state() 827 828 if state["status"] == "closed": 829 # Already closed by someone else. 830 return 831 832 state["status"] = "closed" 833 try: 834 self.client.test_and_set( 835 key=self.get_path("/rdzv/active_version"), 836 value=json.dumps(state), 837 prev_value=active_version.value, 838 ) 839 return 840 841 except etcd.EtcdCompareFailed: 842 logger.info("Set closed CAS unsuccessful, retrying") 843 cas_delay() 844 845 def get_rdzv_state(self): 846 active_version = self.client.get(key=self.get_path("/rdzv/active_version")) 847 return active_version, json.loads(active_version.value) 848 849 def try_wait_for_state_change(self, etcd_index, timeout=None): 850 # Don't sleep past the overall deadline (at least more than by 1s) 851 overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0 852 timeout = overall_timeout if timeout is None else min(timeout, overall_timeout) 853 854 try: 855 self.client.watch( 856 self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout 857 ) 858 except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): 859 pass 860 861 if time.time() > self._rendezvous_deadline: 862 raise RendezvousTimeoutError 863 864 # Unfortunately, we have to do another fetch in order to get last etcd_index. 865 return self.get_rdzv_state() 866 867 def get_path(self, path): 868 if not path.startswith("/"): 869 path = "/" + path 870 871 return f"{self._prefix}run_{self._run_id}{path}" 872 873 def create_path_if_not_exists(self, full_path, ttl=None): 874 try: 875 self.client.write( 876 key=full_path, value=None, dir=True, prevExist=False, ttl=ttl 877 ) 878 except etcd.EtcdAlreadyExist: 879 pass 880 881 def setup_lease_renewal(self, full_path, ttl): 882 # NOTE: For ephemeral key TTL renewal (~lease) to work correctly, 883 # make sure you don't call any long-blocking methods that do not 884 # release the Python's GIL! An example of this is calling a pybind11 885 # extension function that is blocking / long-running, but is not 886 # doing a scoped release of the GIL. 887 def lease_worker(client, path, ttl, stop_event): 888 while True: 889 try: 890 client.refresh(path, ttl=ttl) 891 except etcd.EtcdKeyNotFound: 892 break 893 except ConnectionRefusedError: 894 # This error usually occurs during test when the server already got terminated but the 895 # python garbage collector have not yet invoked the __del__ method. 896 break 897 898 if stop_event.wait(timeout=ttl / 2): 899 break 900 901 lease_stop_event = threading.Event() 902 lease_thread = threading.Thread( 903 target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event) 904 ) 905 906 lease_thread.daemon = True 907 lease_thread.start() 908 909 return lease_stop_event 910 911 def store_extra_data(self, rdzv_version, key, value): 912 node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") 913 try: 914 # If first time we are storing anything: 915 extra_data = self.client.write( 916 key=node, value=json.dumps({key: value}), prevExist=False 917 ) 918 return 919 except etcd.EtcdAlreadyExist: 920 pass 921 922 # CAS loop, to make sure we don't lose concurrent stores. 923 while True: 924 # We never delete extra_data. Failure here should be fatal, no special handling. 925 extra_data = self.client.get(node) 926 927 new_extra_data_value = json.loads(extra_data.value) 928 new_extra_data_value[key] = value 929 930 try: 931 extra_data = self.client.test_and_set( 932 key=node, 933 value=json.dumps(new_extra_data_value), 934 prev_value=extra_data.value, 935 ) 936 return 937 except etcd.EtcdCompareFailed: 938 logger.info("Store extra_data CAS unsuccessful, retrying") 939 time.sleep(0.1) 940 941 def load_extra_data(self, rdzv_version, key, timeout=None): 942 # 'extra_data' node itself, and the directory it is located in: 943 node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data") 944 node_dir = self.get_path(f"/rdzv/v_{rdzv_version}") 945 946 # TODO: implement timeout 947 # https://github.com/pytorch/elastic/issues/12 948 while True: 949 # Combined wait for the node itself, and the key inside it. 950 root = self.client.get(node_dir) 951 952 # Find the extra_data node, if it exists 953 extra_data = [n for n in root.children if n.key == node] 954 assert len(extra_data) <= 1 955 956 # Node for extra_data exists, check the desired key inside it. 957 if len(extra_data) == 1: 958 extra_data_dict = json.loads(extra_data[0].value) 959 if key in extra_data_dict: 960 return extra_data_dict[key] 961 962 # The 'extra_data' node doesn't exist, or they key isn't published yet. 963 # Wait for interesting events on the extra_data node and retry. 964 try: 965 self.client.watch(node, index=root.etcd_index + 1) 966 except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut): 967 pass 968 969 def setup_kv_store(self, rdzv_version): 970 store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv") 971 self.create_path_if_not_exists(store_path) 972 return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path) 973 974 975def _create_etcd_client(params: RendezvousParameters) -> etcd.Client: 976 """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``.""" 977 hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379) 978 979 # The communication protocol 980 protocol = params.config.get("protocol") 981 if protocol is None: 982 protocol = "http" 983 else: 984 if protocol != "http" and protocol != "https": 985 raise ValueError("The etcd protocol must be HTTP or HTTPS.") 986 987 # The SSL client certificate 988 ssl_cert = params.config.get("cert") 989 if ssl_cert is not None: 990 cert_key = params.config.get("key") 991 if cert_key is not None: 992 # The etcd client expects the certificate key as the second element 993 # of the `cert` tuple. 994 ssl_cert = (ssl_cert, cert_key) 995 996 # The root certificate 997 ca_cert = params.config.get("cacert") 998 999 return etcd.Client( 1000 hostname, 1001 port, 1002 protocol=protocol, 1003 cert=ssl_cert, 1004 ca_cert=ca_cert, 1005 allow_reconnect=True, 1006 ) 1007 1008 1009# Handler for torch.distributed "static" registration 1010def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler: 1011 """ 1012 Usage: 1013 1014 :: 1015 1016 rdzv_params = RendezvousParameters( 1017 backend="etcd", 1018 endpoint="192.168.0.42:2379", 1019 run_id="123", 1020 min_nodes=4, 1021 max_nodes=8, 1022 timeout=300, 1023 last_call_timeout=30, 1024 etcd_prefix="custom_prefix", 1025 protocol="https", 1026 cacert="/etc/kubernetes/certs/ca.crt", 1027 cert="/etc/kubernetes/certs/client.crt", 1028 key="/etc/kubernetes/certs/client.key") 1029 # -- or -- 1030 rdzv_params = RendezvousParameters( 1031 backend="etcd", 1032 endpoint="192.168.0.42:2379", 1033 run_id="123", 1034 min_nodes=4, 1035 max_nodes=8) 1036 1037 etcd_rdzv_handler = create_etcd_rendezvous_handler(rdzv_params) 1038 1039 1040 Where: 1041 run_id - unique id for this training job instance, 1042 min_nodes - min number of workers expected to join the rendezvous, 1043 max_nodes - max number of workers allowed to join the rendezvous, 1044 defaults to min_workers is not specified. 1045 timeout - total timeout within which next_rendezvous is expected to 1046 succeed; a RendezvousTimeoutError is raised otherwise; 1047 Defaults is 600 (10 minutes). 1048 last_call_timeout - additional wait amount ("last call") after 1049 min number of workers has been reached. 1050 Defaults to 30 seconds. 1051 etcd_prefix - path prefix (from etcd root), inside which all 1052 etcd nodes will be created. 1053 Default is "/torchelastic/p2p". 1054 protocol - http (default) or https to access etcd. 1055 cacert - CA cert to access etcd, only makes sense with https. 1056 cert - client cert to access etcd, only makes sense with https. 1057 key - client key to access etcd, only makes sense with https. 1058 """ 1059 client = _create_etcd_client(params) 1060 1061 etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p") 1062 1063 rdzv = EtcdRendezvous( 1064 client=client, 1065 prefix=etcd_prefix, 1066 run_id=params.run_id, 1067 num_min_workers=params.min_nodes, 1068 num_max_workers=params.max_nodes, 1069 timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT), 1070 last_call_timeout=params.get_as_int( 1071 "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT 1072 ), 1073 ) 1074 return EtcdRendezvousHandler( 1075 rdzv_impl=rdzv, 1076 local_addr=params.local_addr, 1077 ) 1078