1# mypy: allow-untyped-defs 2# Copyright (c) Facebook, Inc. and its affiliates. 3# All rights reserved. 4# 5# This source code is licensed under the BSD-style license found in the 6# LICENSE file in the root directory of this source tree. 7 8import inspect 9import logging 10import os 11import pickle 12import socket 13import threading 14import time 15import weakref 16from abc import ABC, abstractmethod 17from dataclasses import dataclass 18from datetime import datetime, timedelta 19from enum import Enum 20from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple 21 22import torch.distributed as dist 23from torch.distributed import Store 24from torch.distributed.elastic.events import construct_and_record_rdzv_event, NodeState 25 26from .api import ( 27 RendezvousClosedError, 28 RendezvousError, 29 RendezvousGracefulExitError, 30 RendezvousHandler, 31 RendezvousInfo, 32 RendezvousParameters, 33 RendezvousStateError, 34 RendezvousStoreInfo, 35 RendezvousTimeoutError, 36) 37from .utils import _delay, _PeriodicTimer 38 39 40__all__ = [ 41 "RendezvousBackend", 42 "RendezvousTimeout", 43 "RendezvousSettings", 44 "DynamicRendezvousHandler", 45 "create_handler", 46] 47 48logger = logging.getLogger(__name__) 49 50 51def get_method_name(depth=2): 52 if len(inspect.stack()) > depth: 53 return inspect.stack()[depth].function 54 return "no_method_name" 55 56 57Token = Any 58"""Represent an opaque fencing token used by the rendezvous backend.""" 59 60 61class RendezvousBackend(ABC): 62 """Represent a backend that holds the rendezvous state.""" 63 64 @property 65 @abstractmethod 66 def name(self) -> str: 67 """Get the name of the backend.""" 68 69 @abstractmethod 70 def get_state(self) -> Optional[Tuple[bytes, Token]]: 71 """Get the rendezvous state. 72 73 Returns: 74 A tuple of the encoded rendezvous state and its fencing token or 75 ``None`` if no state is found in the backend. 76 77 Raises: 78 RendezvousConnectionError: 79 The connection to the backend has failed. 80 RendezvousStateError: 81 The rendezvous state is corrupt. 82 """ 83 84 @abstractmethod 85 def set_state( 86 self, state: bytes, token: Optional[Token] = None 87 ) -> Optional[Tuple[bytes, Token, bool]]: 88 """Set the rendezvous state. 89 90 The new rendezvous state is set conditionally: 91 92 - If the specified ``token`` matches the fencing token stored in the 93 backend, the state will be updated. The new state will be returned 94 to the caller along with its fencing token. 95 - If the specified ``token`` does not match the fencing token stored 96 in the backend, the state won't be updated; instead the existing 97 state along with its fencing token will be returned to the caller. 98 - If the specified ``token`` is ``None``, the new state will be set 99 only if there is no existing state in the backend. Either the new 100 state or the existing state along with its fencing token will be 101 returned to the caller. 102 103 Args: 104 state: 105 The encoded rendezvous state. 106 token: 107 An optional fencing token that was retrieved by a previous call 108 to :py:meth:`get_state` or ``set_state()``. 109 110 Returns: 111 A tuple of the serialized rendezvous state, its fencing token, and 112 a boolean value indicating whether our set attempt succeeded. 113 114 Raises: 115 RendezvousConnectionError: 116 The connection to the backend has failed. 117 RendezvousStateError: 118 The rendezvous state is corrupt. 119 """ 120 121 122class RendezvousTimeout: 123 """Hold the timeout configuration of a rendezvous. 124 125 Args: 126 join: 127 The time within which the rendezvous is expected to complete. 128 last_call: 129 An additional wait amount before completing the rendezvous once the 130 rendezvous has the minimum number of required participants. 131 close: 132 The time within which the rendezvous is expected to close after a 133 call to :py:meth:`RendezvousHandler.set_closed` or 134 :py:meth:`RendezvousHandler.shutdown`. 135 keep_alive: 136 The time within which a keep-alive heartbeat is expected to 137 complete. 138 """ 139 140 _ZERO = timedelta(0) 141 142 _DEFAULT_TIMEOUTS = { 143 "join": timedelta(seconds=600), 144 "last_call": timedelta(seconds=30), 145 "close": timedelta(seconds=30), 146 "heartbeat": timedelta(seconds=5), 147 } 148 149 _join: timedelta 150 _last_call: timedelta 151 _close: timedelta 152 _heartbeat: timedelta 153 154 def __init__( 155 self, 156 join: Optional[timedelta] = None, 157 last_call: Optional[timedelta] = None, 158 close: Optional[timedelta] = None, 159 heartbeat: Optional[timedelta] = None, 160 ) -> None: 161 self._set_timeouts( 162 join=join, last_call=last_call, close=close, heartbeat=heartbeat 163 ) 164 165 @property 166 def join(self) -> timedelta: 167 """Get the join timeout.""" 168 return self._join 169 170 @property 171 def last_call(self) -> timedelta: 172 """Get the last call timeout.""" 173 return self._last_call 174 175 @property 176 def close(self) -> timedelta: 177 """Get the close timeout.""" 178 return self._close 179 180 @property 181 def heartbeat(self) -> timedelta: 182 """Get the keep-alive heartbeat timeout.""" 183 return self._heartbeat 184 185 def _set_timeouts(self, **timeouts: Optional[timedelta]): 186 for name, timeout in timeouts.items(): 187 if timeout is None: 188 timeout = self._DEFAULT_TIMEOUTS[name] 189 if timeout <= self._ZERO: 190 raise ValueError(f"The {name} timeout ({timeout}) must be positive.") 191 setattr(self, "_" + name, timeout) 192 193 194@dataclass(repr=False, eq=False, frozen=True) 195class RendezvousSettings: 196 """Hold the settings of the rendezvous. 197 198 Attributes: 199 run_id: 200 The run id of the rendezvous. 201 min_nodes: 202 The minimum number of nodes to admit to the rendezvous. 203 max_nodes: 204 The maximum number of nodes to admit to the rendezvous. 205 timeout: 206 The timeout configuration of the rendezvous. 207 keep_alive_interval: 208 The amount of time a node waits before sending a heartbeat to keep 209 it alive in the rendezvous. 210 keep_alive_max_attempt: 211 The maximum number of failed heartbeat attempts after which a node 212 is considered dead. 213 """ 214 215 run_id: str 216 min_nodes: int 217 max_nodes: int 218 timeout: RendezvousTimeout 219 keep_alive_interval: timedelta 220 keep_alive_max_attempt: int 221 222 223@dataclass(eq=True, order=True, frozen=True) 224class _NodeDesc: 225 """Describe a node in the rendezvous. 226 227 Attributes: 228 addr: 229 The FQDN of the node or user specified local node address. 230 pid: 231 The id of the process in which the rendezvous handler runs. 232 local_id: 233 A process-wide unique id. 234 """ 235 236 addr: str 237 pid: int 238 local_id: int 239 240 def __repr__(self) -> str: 241 return f"{self.addr}_{self.pid}_{self.local_id}" 242 243 244class _NodeDescGenerator: 245 """Generate node descriptors. 246 247 A node descriptor is a combination of an FQDN, a process id, and an auto- 248 incremented integer that uniquely identifies a node in the rendezvous. 249 """ 250 251 _lock: threading.Lock 252 _local_id: int 253 254 def __init__(self) -> None: 255 self._lock = threading.Lock() 256 257 # An integer that is incremented with each call to generate(). 258 self._local_id = 0 259 260 def generate(self, local_addr: Optional[str] = None) -> _NodeDesc: 261 # This method can be called by multiple threads concurrently; therefore, 262 # we must increment the integer atomically. 263 with self._lock: 264 local_id = self._local_id 265 266 self._local_id += 1 267 268 return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id) 269 270 271class _RendezvousState: 272 """Hold the state of a rendezvous. 273 274 Attributes: 275 round: 276 The current round of the rendezvous. 277 complete: 278 A boolean value indicating whether the current round of the 279 rendezvous is complete. 280 deadline: 281 The time at which the current round of the rendezvous will be 282 considered complete if it is still waiting for nodes to join. 283 closed: 284 A boolean value indicating whether the rendezvous is closed. 285 participants: 286 A dictionary of the participants and their corresponding ranks. 287 wait_list: 288 A set of nodes that are waiting to participate in the next round of 289 the rendezvous. 290 redundancy_list: 291 A set of nodes that are redundant in the current round and can join 292 the next rendezvous without triggering re-rendezvous. 293 last_heartbeats: 294 A dictionary containing each node's last heartbeat time. 295 """ 296 297 round: int 298 complete: bool 299 deadline: Optional[datetime] 300 closed: bool 301 participants: Dict[_NodeDesc, int] 302 wait_list: Set[_NodeDesc] 303 redundancy_list: Set[_NodeDesc] 304 last_heartbeats: Dict[_NodeDesc, datetime] 305 306 def __init__(self) -> None: 307 self.round = 0 308 self.complete = False 309 self.deadline = None 310 self.closed = False 311 self.participants = {} 312 self.wait_list = set() 313 self.redundancy_list = set() 314 self.last_heartbeats = {} 315 316 317def _remove_participant_epilogue( 318 state: _RendezvousState, settings: RendezvousSettings 319) -> None: 320 if state.complete: 321 # If we do not have any participants left, move to the next round. 322 if not state.participants: 323 msg = "No participants left in the rendezvous, marking rendezvous as incomplete" 324 logger.debug(msg) 325 state.complete = False 326 327 state.round += 1 328 else: 329 if len(state.participants) < settings.min_nodes: 330 msg = ( 331 f"Number of participants {len(state.participants)}) less than" 332 f"min_nodes {settings.min_nodes}, clearning deadline in state" 333 ) 334 logger.debug(msg) 335 state.deadline = None 336 337 338class _RendezvousStateHolder(ABC): 339 """Hold the shared rendezvous state synced with other nodes.""" 340 341 @property 342 @abstractmethod 343 def state(self) -> _RendezvousState: 344 """Get the local state.""" 345 346 @abstractmethod 347 def sync(self) -> Optional[bool]: 348 """Read or writes the latest state. 349 350 Returns: 351 A boolean value indicating whether the local state, in case marked 352 as dirty, was successfully synced with other nodes. 353 """ 354 355 @abstractmethod 356 def mark_dirty(self) -> None: 357 """Mark the local state as dirty.""" 358 359 360class _BackendRendezvousStateHolder(_RendezvousStateHolder): 361 """Hold the rendezvous state synced with other nodes via a backend. 362 363 Args: 364 backend: 365 The rendezvous backend to use. 366 settings: 367 The rendezvous settings. 368 cache_duration: 369 The amount of time, in seconds, to cache the last rendezvous state 370 before requesting it from the backend again. 371 """ 372 373 _backend: RendezvousBackend 374 _state: _RendezvousState 375 _settings: RendezvousSettings 376 _cache_duration: int 377 _token: Token 378 _dirty: bool 379 _last_sync_time: float 380 _dead_nodes: List[_NodeDesc] 381 382 def __init__( 383 self, 384 backend: RendezvousBackend, 385 settings: RendezvousSettings, 386 cache_duration: int = 1, 387 ) -> None: 388 self._backend = backend 389 self._state = _RendezvousState() 390 self._settings = settings 391 self._cache_duration = cache_duration 392 self._token = None 393 self._dirty = False 394 self._last_sync_time = -1 395 self._dead_nodes = [] 396 397 def _record(self, message: str, node_state: NodeState = NodeState.RUNNING): 398 construct_and_record_rdzv_event( 399 name=f"{self.__class__.__name__}.{get_method_name()}", 400 run_id=self._settings.run_id, 401 message=message, 402 node_state=node_state, 403 ) 404 405 @property 406 def state(self) -> _RendezvousState: 407 """See base class.""" 408 return self._state 409 410 def sync(self) -> Optional[bool]: 411 """See base class.""" 412 state_bits: Optional[bytes] = None 413 414 token = None 415 416 has_set: Optional[bool] 417 418 if self._dirty: 419 has_set = False 420 421 state_bits = pickle.dumps(self._state) 422 423 set_response = self._backend.set_state(state_bits, self._token) 424 if set_response is not None: 425 state_bits, token, has_set = set_response 426 else: 427 has_set = None 428 429 if self._cache_duration > 0: 430 # Avoid overloading the backend if we are asked to retrieve the 431 # state repeatedly. Try to serve the cached state. 432 if self._last_sync_time >= max( 433 time.monotonic() - self._cache_duration, 0 434 ): 435 return None 436 437 get_response = self._backend.get_state() 438 if get_response is not None: 439 state_bits, token = get_response 440 441 if state_bits is not None: 442 try: 443 self._state = pickle.loads(state_bits) 444 except pickle.PickleError as exc: 445 raise RendezvousStateError( 446 "The rendezvous state is corrupt. See inner exception for details." 447 ) from exc 448 else: 449 self._state = _RendezvousState() 450 451 if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG): 452 node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes) 453 454 msg = ( 455 f"As part of the sync operation the node(s) {node_list} have been removed from the " 456 f"rendezvous '{self._settings.run_id}' since they had no heartbeat." 457 ) 458 self._record(message=msg) 459 logger.debug(msg) 460 461 self._token = token 462 463 self._dirty = False 464 465 self._last_sync_time = time.monotonic() 466 467 self._sanitize() 468 469 return has_set 470 471 def _sanitize(self) -> None: 472 state = self._state 473 474 expire_time = datetime.utcnow() - ( 475 self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt 476 ) 477 478 # Filter out the dead nodes. 479 self._dead_nodes = [ 480 node 481 for node, last_heartbeat in state.last_heartbeats.items() 482 if last_heartbeat < expire_time 483 ] 484 485 participant_removed = False 486 487 for dead_node in self._dead_nodes: 488 msg = f"Detected dead node '{dead_node}', removing it from the rendezvous" 489 logger.debug(msg) 490 del state.last_heartbeats[dead_node] 491 492 try: 493 del state.participants[dead_node] 494 495 participant_removed = True 496 except KeyError: 497 pass 498 499 try: 500 state.wait_list.remove(dead_node) 501 except KeyError: 502 pass 503 504 try: 505 state.redundancy_list.remove(dead_node) 506 except KeyError: 507 pass 508 509 if participant_removed: 510 # Common epilogue shared with the _remove_from_participants() 511 # function of _DistributedRendezvousOpExecutor. 512 _remove_participant_epilogue(state, self._settings) 513 514 def mark_dirty(self) -> None: 515 """See base class. 516 517 If the local rendezvous state is dirty, the next sync call will try to 518 write the changes back to the backend. However this attempt might fail 519 if another node, which had the same state, also made changes and wrote 520 them before us. 521 """ 522 self._dirty = True 523 524 525class _Action(Enum): 526 """Specifies the possible actions based on the state of the rendezvous.""" 527 528 KEEP_ALIVE = 1 529 ADD_TO_PARTICIPANTS = 2 530 ADD_TO_WAIT_LIST = 3 531 ADD_TO_REDUNDANCY_LIST = 4 532 REMOVE_FROM_PARTICIPANTS = 5 533 REMOVE_FROM_WAIT_LIST = 6 534 REMOVE_FROM_REDUNDANCY_LIST = 7 535 MARK_RENDEZVOUS_COMPLETE = 8 536 MARK_RENDEZVOUS_CLOSED = 9 537 SYNC = 10 538 ERROR_CLOSED = 11 539 ERROR_TIMEOUT = 12 540 FINISH = 13 541 542 543class _RendezvousContext: 544 """Holds the context of the rendezvous. 545 546 Attributes: 547 node: 548 The node descriptor associated with the current rendezvous handler 549 instance. 550 state: 551 The current state of the rendezvous. 552 settings: 553 The rendezvous settings. 554 """ 555 556 node: _NodeDesc 557 state: _RendezvousState 558 settings: RendezvousSettings 559 560 def __init__( 561 self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings 562 ) -> None: 563 self.node = node 564 self.state = state 565 self.settings = settings 566 567 568class _RendezvousOpExecutor(ABC): 569 """Execute rendezvous operations.""" 570 571 @abstractmethod 572 def run( 573 self, 574 state_handler: Callable[[_RendezvousContext, float], _Action], 575 deadline: float, 576 update_deadline: Optional[Callable[[timedelta], float]] = None, 577 ) -> None: 578 """Execute a rendezvous operation. 579 580 An operation is run inside a state machine and is expected to transition 581 the rendezvous from one state to another. 582 583 Args: 584 state_handler: 585 A callable that is expected to return the next state transition 586 action based on the current state of the rendezvous. 587 deadline: 588 The time, in seconds, at which the operation will be considered 589 timed-out. 590 update_deadline: 591 Function to generate a new operation deadline if the current 592 node may participate in the next rendezvous. 593 """ 594 595 596class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor): 597 """Execute rendezvous operations using a shared state. 598 599 Args: 600 node: 601 The node descriptor associated with the current rendezvous handler 602 instance. 603 state_holder: 604 The ``RendezvousStateHolder`` to use to sync the rendezvous state 605 with other nodes. 606 settings: 607 The rendezvous settings. 608 """ 609 610 _node: _NodeDesc 611 _state: _RendezvousState 612 _state_holder: _RendezvousStateHolder 613 _settings: RendezvousSettings 614 615 def __init__( 616 self, 617 node: _NodeDesc, 618 state_holder: _RendezvousStateHolder, 619 settings: RendezvousSettings, 620 ) -> None: 621 self._node = node 622 self._state_holder = state_holder 623 self._settings = settings 624 625 def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None: 626 construct_and_record_rdzv_event( 627 name=f"{self.__class__.__name__}.{get_method_name()}", 628 run_id=self._settings.run_id, 629 message=message, 630 node_state=node_state, 631 hostname=self._node.addr, 632 pid=self._node.pid, 633 local_id=self._node.local_id, 634 ) 635 636 def run( 637 self, 638 state_handler: Callable[[_RendezvousContext, float], _Action], 639 deadline: float, 640 update_deadline: Optional[Callable[[timedelta], float]] = None, 641 ) -> None: 642 """See base class.""" 643 action = None 644 while action != _Action.FINISH: 645 # Reads or writes the latest rendezvous state shared by all nodes in 646 # the rendezvous. Note that our local changes might get overridden 647 # by another node if that node synced its changes before us. 648 has_set = self._state_holder.sync() 649 if has_set is not None: 650 if has_set: 651 msg = ( 652 f"The node '{self._node}' has successfully synced its local changes with " 653 f"other nodes in the rendezvous '{self._settings.run_id}'." 654 ) 655 else: 656 msg = ( 657 f"The node '{self._node}' has a stale state and failed to sync its local " 658 f"changes with other nodes in the rendezvous '{self._settings.run_id}'." 659 ) 660 661 self._record(message=msg) 662 logger.debug(msg) 663 664 self._state = self._state_holder.state 665 666 ctx = _RendezvousContext(self._node, self._state, self._settings) 667 668 # Determine the next action to take based on the current state of 669 # the rendezvous. 670 action = state_handler(ctx, deadline) 671 672 if action == _Action.FINISH: 673 continue 674 675 if action == _Action.ERROR_CLOSED: 676 raise RendezvousClosedError 677 678 if action == _Action.ERROR_TIMEOUT: 679 raise RendezvousTimeoutError 680 681 if action == _Action.SYNC: 682 # Delay the execution by one second to avoid overloading the 683 # backend if we are asked to poll for state changes. 684 _delay(seconds=1) 685 else: 686 if action == _Action.KEEP_ALIVE: 687 self._keep_alive() 688 elif action == _Action.ADD_TO_PARTICIPANTS: 689 self._add_to_participants() 690 elif action == _Action.ADD_TO_WAIT_LIST: 691 self._add_to_wait_list() 692 elif action == _Action.ADD_TO_REDUNDANCY_LIST: 693 self._add_to_redundancy_list() 694 elif action == _Action.REMOVE_FROM_PARTICIPANTS: 695 self._remove_from_participants() 696 elif action == _Action.REMOVE_FROM_WAIT_LIST: 697 self._remove_from_wait_list() 698 elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST: 699 self._remove_from_redundancy_list() 700 # update deadline since the node may participate in rendezvous process 701 if update_deadline: 702 deadline = update_deadline(self._settings.timeout.join) 703 elif action == _Action.MARK_RENDEZVOUS_COMPLETE: 704 self._mark_rendezvous_complete() 705 elif action == _Action.MARK_RENDEZVOUS_CLOSED: 706 self._mark_rendezvous_closed() 707 708 # Attempt to sync our changes back to other nodes. 709 self._state_holder.mark_dirty() 710 711 def _keep_alive(self) -> None: 712 msg = ( 713 f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous " 714 f"'{self._settings.run_id}'. Pending sync." 715 ) 716 self._record(message=msg) 717 logger.debug(msg) 718 719 self._state.last_heartbeats[self._node] = datetime.utcnow() 720 721 def _add_to_participants(self) -> None: 722 msg = ( 723 f"The node '{self._node}' added itself to the participants of round " 724 f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." 725 ) 726 self._record(message=msg) 727 logger.debug(msg) 728 729 state = self._state 730 731 try: 732 state.wait_list.remove(self._node) 733 except KeyError: 734 pass 735 736 # The ranks of the participants will be set once the rendezvous is 737 # complete. 738 state.participants[self._node] = 0 739 740 self._keep_alive() 741 742 if len(state.participants) == self._settings.min_nodes: 743 state.deadline = datetime.utcnow() + self._settings.timeout.last_call 744 745 if len(state.participants) == self._settings.max_nodes: 746 self._mark_rendezvous_complete() 747 748 def _add_to_wait_list(self) -> None: 749 msg = ( 750 f"The node '{self._node}' added itself to the wait list of round " 751 f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." 752 ) 753 self._record(message=msg) 754 logger.debug(msg) 755 756 if self._node in self._state.redundancy_list: 757 self._state.redundancy_list.remove(self._node) 758 self._state.wait_list.add(self._node) 759 760 self._keep_alive() 761 762 def _add_to_redundancy_list(self) -> None: 763 msg = ( 764 f"The node '{self._node}' added itself to the redundancy list of round " 765 f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." 766 ) 767 self._record(message=msg) 768 logger.debug(msg) 769 770 self._state.redundancy_list.add(self._node) 771 772 self._keep_alive() 773 774 def _remove_from_participants(self) -> None: 775 msg = ( 776 f"The node '{self._node}' removed itself from the participants of round " 777 f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync." 778 ) 779 self._record(message=msg) 780 logger.debug(msg) 781 782 state = self._state 783 784 del state.participants[self._node] 785 786 del state.last_heartbeats[self._node] 787 788 # Common epilogue shared with the sanitizer() function of 789 # _BackendRendezvousStateHolder. 790 _remove_participant_epilogue(state, self._settings) 791 792 def _remove_from_wait_list(self) -> None: 793 msg = ( 794 f"The node '{self._node}' removed itself from the wait list of round " 795 f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." 796 ) 797 self._record(message=msg) 798 logger.debug(msg) 799 800 self._state.wait_list.remove(self._node) 801 802 del self._state.last_heartbeats[self._node] 803 804 def _remove_from_redundancy_list(self) -> None: 805 msg = ( 806 f"The node '{self._node}' removed itself from the redunant list of round " 807 f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync." 808 ) 809 self._record(message=msg) 810 logger.debug(msg) 811 812 self._state.redundancy_list.remove(self._node) 813 814 del self._state.last_heartbeats[self._node] 815 816 def _mark_rendezvous_complete(self) -> None: 817 msg = ( 818 f"The node '{self._node}' marked round {self._state.round} of the rendezvous " 819 f"'{self._settings.run_id}' as complete. Pending sync." 820 ) 821 self._record(message=msg, node_state=NodeState.SUCCEEDED) 822 logger.debug(msg) 823 824 state = self._state 825 826 state.complete = True 827 state.deadline = None 828 829 # Assign the ranks. 830 for rank, node in enumerate(sorted(state.participants)): 831 state.participants[node] = rank 832 833 def _mark_rendezvous_closed(self) -> None: 834 msg = ( 835 f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. " 836 "Pending sync." 837 ) 838 self._record(message=msg, node_state=NodeState.SUCCEEDED) 839 logger.debug(msg) 840 841 self._state.closed = True 842 843 844def _should_keep_alive(ctx: _RendezvousContext) -> bool: 845 """Determine whether a keep-alive heartbeat should be sent.""" 846 try: 847 last_heartbeat = ctx.state.last_heartbeats[ctx.node] 848 except KeyError: 849 return False 850 851 return last_heartbeat <= datetime.utcnow() - ctx.settings.keep_alive_interval 852 853 854class _RendezvousExitOp: 855 """Represent a rendezvous exit operation.""" 856 857 def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: 858 if ctx.node in ctx.state.participants: 859 if time.monotonic() > deadline: 860 return _Action.ERROR_TIMEOUT 861 return _Action.REMOVE_FROM_PARTICIPANTS 862 return _Action.FINISH 863 864 865class _RendezvousJoinOp: 866 """Represent a rendezvous join operation.""" 867 868 def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: 869 state = ctx.state 870 871 # A closed rendezvous means that it no longer accepts new nodes. 872 if state.closed: 873 if ctx.node in state.redundancy_list: 874 msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous." 875 raise RendezvousGracefulExitError(msg) 876 return _Action.ERROR_CLOSED 877 878 if ctx.node in state.redundancy_list: 879 msg = f"The node {ctx.node} is in redunancy list" 880 logger.debug(msg) 881 # don't apply the timeout logic here, since we want to allow the node to rejoin 882 if len(state.participants) == ctx.settings.max_nodes: 883 if _should_keep_alive(ctx): 884 return _Action.KEEP_ALIVE 885 else: 886 return _Action.SYNC 887 else: 888 # transition to waiting state that will respect timeouts. 889 msg = f"The node {ctx.node} is removed from redunancy list" 890 logger.debug(msg) 891 return _Action.REMOVE_FROM_REDUNDANCY_LIST 892 893 is_participant = ctx.node in state.participants 894 895 # If we are part of the rendezvous and it is already complete there is 896 # no further action to take. 897 if state.complete and is_participant: 898 return _Action.FINISH 899 900 now = time.monotonic() 901 if now > deadline: 902 rollback_period = 5 # 5 seconds 903 904 # If we still have time to rollback (a short period on top of the 905 # operation deadline), try to remove ourself from the rendezvous. 906 # It is okay if we can't though as our keep-alive will eventually 907 # expire. 908 if now <= deadline + rollback_period: 909 # If we are part of the rendezvous, it means we couldn't find 910 # enough participants to complete it on time. 911 if is_participant: 912 return _Action.REMOVE_FROM_PARTICIPANTS 913 # If we are in the wait list, it means we couldn't wait till the 914 # next round of the rendezvous. 915 if ctx.node in state.wait_list: 916 return _Action.REMOVE_FROM_WAIT_LIST 917 return _Action.ERROR_TIMEOUT 918 919 if state.complete: 920 # If we are here, it means we are not part of the rendezvous. In 921 # case the rendezvous has capacity for additional participants add 922 # ourself to the wait list for the next round. 923 if len(state.participants) < ctx.settings.max_nodes: 924 if ctx.node not in state.wait_list: 925 return _Action.ADD_TO_WAIT_LIST 926 elif len(state.participants) >= ctx.settings.max_nodes: 927 if ( 928 ctx.node not in state.redundancy_list 929 and ctx.node not in state.wait_list 930 ): 931 return _Action.ADD_TO_REDUNDANCY_LIST 932 elif is_participant: 933 # If the rendezvous has enough number of participants including us, 934 # check whether we have passed the rendezvous deadline. If yes, 935 # complete it. 936 if ( 937 len(state.participants) >= ctx.settings.min_nodes 938 and len(state.participants) <= ctx.settings.max_nodes 939 ): 940 if cast(datetime, state.deadline) < datetime.utcnow(): 941 msg = ( 942 f"The node '{ctx.node}' marking the rendezvous complete, " 943 f"quorum established within deadline" 944 ) 945 logger.debug(msg) 946 return _Action.MARK_RENDEZVOUS_COMPLETE 947 else: 948 msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached" 949 logger.debug(msg) 950 else: 951 msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants" 952 logger.debug(msg) 953 else: 954 # The rendezvous is not complete yet and we are not part of it. Try 955 # to join. 956 return _Action.ADD_TO_PARTICIPANTS 957 958 if _should_keep_alive(ctx): 959 return _Action.KEEP_ALIVE 960 961 # At this point either the rendezvous is not complete, but we are part 962 # of it, which means we have to wait for other participants to join; or 963 # the rendezvous is complete, but we are not part of it, which means we 964 # have to wait for the next round. 965 return _Action.SYNC 966 967 968class _RendezvousCloseOp: 969 """Represent a rendezvous close operation.""" 970 971 def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: 972 if ctx.state.closed: 973 return _Action.FINISH 974 if time.monotonic() > deadline: 975 return _Action.ERROR_TIMEOUT 976 return _Action.MARK_RENDEZVOUS_CLOSED 977 978 979class _RendezvousKeepAliveOp: 980 """Represent a rendezvous keep-alive update operation.""" 981 982 def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action: 983 if _should_keep_alive(ctx): 984 if time.monotonic() > deadline: 985 return _Action.ERROR_TIMEOUT 986 return _Action.KEEP_ALIVE 987 return _Action.FINISH 988 989 990class DynamicRendezvousHandler(RendezvousHandler): 991 """Represent a handler that sets up a rendezvous among a set of nodes.""" 992 993 # Static 994 _node_desc_generator = _NodeDescGenerator() 995 996 _this_node: _NodeDesc 997 _settings: RendezvousSettings 998 _backend_name: str 999 _store: Store 1000 _state_holder: _RendezvousStateHolder 1001 _op_executor: _RendezvousOpExecutor 1002 _heartbeat_lock: threading.Lock 1003 _keep_alive_timer: Optional[_PeriodicTimer] 1004 1005 @classmethod 1006 def from_backend( 1007 cls, 1008 run_id: str, 1009 store: Store, 1010 backend: RendezvousBackend, 1011 min_nodes: int, 1012 max_nodes: int, 1013 local_addr: Optional[str] = None, 1014 timeout: Optional[RendezvousTimeout] = None, 1015 ): 1016 """Create a new :py:class:`DynamicRendezvousHandler`. 1017 1018 Args: 1019 run_id: 1020 The run id of the rendezvous. 1021 store: 1022 The C10d store to return as part of the rendezvous. 1023 backend: 1024 The backend to use to hold the rendezvous state. 1025 min_nodes: 1026 The minimum number of nodes to admit to the rendezvous. 1027 max_nodes: 1028 The maximum number of nodes to admit to the rendezvous. 1029 local_addr: 1030 The local node address. 1031 timeout: 1032 The timeout configuration of the rendezvous. 1033 """ 1034 # We associate each handler instance with a unique node descriptor. 1035 node = cls._node_desc_generator.generate(local_addr) 1036 1037 settings = RendezvousSettings( 1038 run_id, 1039 min_nodes, 1040 max_nodes, 1041 timeout or RendezvousTimeout(), 1042 keep_alive_interval=timedelta(seconds=5), 1043 keep_alive_max_attempt=3, 1044 ) 1045 1046 state_holder = _BackendRendezvousStateHolder(backend, settings) 1047 1048 return cls(node, settings, backend.name, store, state_holder) 1049 1050 def __init__( 1051 self, 1052 node: _NodeDesc, 1053 settings: RendezvousSettings, 1054 backend_name: str, 1055 store: Store, 1056 state_holder: _RendezvousStateHolder, 1057 ) -> None: 1058 if not settings.run_id: 1059 raise ValueError("The run id must be a non-empty string.") 1060 1061 if settings.min_nodes < 1: 1062 raise ValueError( 1063 f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero." 1064 ) 1065 1066 if settings.max_nodes < settings.min_nodes: 1067 raise ValueError( 1068 f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal " 1069 f"to the minimum number of nodes ({settings.min_nodes})." 1070 ) 1071 1072 self._this_node = node 1073 1074 self._settings = settings 1075 1076 self._backend_name = backend_name 1077 1078 self._store = store 1079 1080 self._state_holder = state_holder 1081 1082 self._op_executor = _DistributedRendezvousOpExecutor( 1083 self._this_node, self._state_holder, self._settings 1084 ) 1085 1086 self._heartbeat_lock = threading.Lock() 1087 1088 self._keep_alive_timer = None 1089 1090 # Cached shared store server reference 1091 self._shared_tcp_store_server: Optional[dist.Store] = None 1092 1093 self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None 1094 1095 def _record( 1096 self, 1097 message: str, 1098 node_state: NodeState = NodeState.RUNNING, 1099 rank: Optional[int] = None, 1100 ) -> None: 1101 construct_and_record_rdzv_event( 1102 name=f"{self.__class__.__name__}.{get_method_name()}", 1103 run_id=self._settings.run_id, 1104 message=message, 1105 node_state=node_state, 1106 hostname=self._this_node.addr, 1107 pid=self._this_node.pid, 1108 local_id=self._this_node.local_id, 1109 rank=rank, 1110 ) 1111 1112 def _create_tcp_store_server(self, bootstrap_store_info) -> dist.TCPStore: 1113 return dist.TCPStore( 1114 bootstrap_store_info.master_addr, 1115 bootstrap_store_info.master_port, 1116 is_master=True, 1117 multi_tenant=True, 1118 ) 1119 1120 @property 1121 def settings(self) -> RendezvousSettings: 1122 """Get the settings of the rendezvous.""" 1123 return self._settings 1124 1125 def get_backend(self) -> str: 1126 """See base class.""" 1127 return self._backend_name 1128 1129 @property 1130 def use_agent_store(self) -> bool: 1131 """See base class.""" 1132 return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1" 1133 1134 def next_rendezvous(self) -> RendezvousInfo: 1135 """See base class.""" 1136 msg = ( 1137 f"The node '{self._this_node}' attempts to join the next round of the rendezvous " 1138 f"'{self._settings.run_id}'." 1139 ) 1140 self._record(message=msg) 1141 logger.info(msg) 1142 1143 try: 1144 self._stop_heartbeats() 1145 1146 # Delay the execution for a small random amount of time if this is our 1147 # first run. This will slightly skew the rendezvous attempts across the 1148 # nodes and reduce the load on the backend. 1149 if self._state_holder.state.round == 0: 1150 _delay(seconds=(0, 0.3)) 1151 1152 exit_op = _RendezvousExitOp() 1153 join_op = _RendezvousJoinOp() 1154 1155 deadline = self._get_deadline(self._settings.timeout.join) 1156 self._op_executor.run(exit_op, deadline) 1157 self._op_executor.run(join_op, deadline, self._get_deadline) 1158 1159 self._start_heartbeats() 1160 1161 rank, world_size = self._get_world() 1162 store = self._get_store() 1163 1164 except Exception as e: 1165 self._record( 1166 message=f"{type(e).__name__}: {str(e)}", 1167 node_state=NodeState.FAILED, 1168 ) 1169 raise 1170 1171 msg = ( 1172 f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of " 1173 f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size " 1174 f"{world_size}." 1175 ) 1176 self._record(message=msg, rank=rank) 1177 logger.info(msg) 1178 1179 # opt-out option of TCP store sharing 1180 if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1": 1181 bootstrap_store_info = RendezvousStoreInfo.build( 1182 rank, store, local_addr=self._this_node.addr 1183 ) 1184 return RendezvousInfo( 1185 store, 1186 rank, 1187 world_size, 1188 bootstrap_store_info, 1189 ) 1190 1191 if self._bootstrap_store_info is None: 1192 if isinstance(self._store, dist.TCPStore): 1193 addr = self._store.host 1194 port = self._store.port 1195 self._bootstrap_store_info = RendezvousStoreInfo( 1196 master_addr=addr, master_port=port 1197 ) 1198 if rank == 0: 1199 self._shared_tcp_store_server = self._store 1200 else: 1201 # If the store is not type of TCPStore start TCPStore server, which requries 1202 # bootstrapping info across ranks 1203 self._bootstrap_store_info = RendezvousStoreInfo.build( 1204 rank, store, local_addr=self._this_node.addr 1205 ) 1206 if rank == 0: 1207 self._shared_tcp_store_server = self._create_tcp_store_server( 1208 self._bootstrap_store_info 1209 ) 1210 1211 assert self._bootstrap_store_info is not None 1212 if rank == 0: 1213 assert self._shared_tcp_store_server is not None 1214 1215 return RendezvousInfo( 1216 store, 1217 rank, 1218 world_size, 1219 self._bootstrap_store_info, # type: ignore[assignment] 1220 ) 1221 1222 def is_closed(self) -> bool: 1223 """See base class.""" 1224 try: 1225 with self._heartbeat_lock: 1226 self._state_holder.sync() 1227 1228 return self._state_holder.state.closed 1229 1230 except Exception as e: 1231 self._record( 1232 message=f"{type(e).__name__}: {str(e)}", 1233 node_state=NodeState.FAILED, 1234 ) 1235 raise 1236 1237 def set_closed(self) -> None: 1238 """See base class.""" 1239 try: 1240 with self._heartbeat_lock: 1241 self._close() 1242 except Exception as e: 1243 self._record( 1244 message=f"{type(e).__name__}: {str(e)}", 1245 node_state=NodeState.FAILED, 1246 ) 1247 raise 1248 1249 def num_nodes_waiting(self) -> int: 1250 """See base class.""" 1251 try: 1252 with self._heartbeat_lock: 1253 self._state_holder.sync() 1254 1255 return len(self._state_holder.state.wait_list) 1256 1257 except Exception as e: 1258 self._record( 1259 message=f"{type(e).__name__}: {str(e)}", 1260 node_state=NodeState.FAILED, 1261 ) 1262 raise 1263 1264 def get_run_id(self) -> str: 1265 """See base class.""" 1266 return self._settings.run_id 1267 1268 def shutdown(self) -> bool: 1269 """See base class.""" 1270 self._stop_heartbeats() 1271 1272 try: 1273 self._close() 1274 1275 return True 1276 except RendezvousError as ex: 1277 msg = ( 1278 f"The node '{self._this_node}' has failed to shutdown the rendezvous " 1279 f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}." 1280 ) 1281 self._record(message=msg, node_state=NodeState.FAILED) 1282 logger.warning(msg) 1283 1284 return False 1285 except Exception as e: 1286 self._record( 1287 message=f"{type(e).__name__}: {str(e)}", 1288 node_state=NodeState.FAILED, 1289 ) 1290 raise 1291 1292 def _close(self) -> None: 1293 op = _RendezvousCloseOp() 1294 1295 deadline = self._get_deadline(self._settings.timeout.close) 1296 1297 self._op_executor.run(op, deadline) 1298 1299 msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'." 1300 self._record(message=msg, node_state=NodeState.SUCCEEDED) 1301 logger.info(msg) 1302 1303 @staticmethod 1304 def _keep_alive_weak(weak_self) -> None: 1305 self = weak_self() 1306 if self is not None: 1307 self._keep_alive() 1308 1309 def _keep_alive(self) -> None: 1310 self._heartbeat_lock.acquire() 1311 1312 op = _RendezvousKeepAliveOp() 1313 1314 deadline = self._get_deadline(self._settings.timeout.heartbeat) 1315 1316 try: 1317 self._op_executor.run(op, deadline) 1318 1319 msg = ( 1320 f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous " 1321 f"'{self._settings.run_id}'." 1322 ) 1323 self._record(message=msg) 1324 logger.debug(msg) 1325 except RendezvousError as ex: 1326 msg = ( 1327 f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the " 1328 f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}." 1329 ) 1330 self._record(message=msg, node_state=NodeState.FAILED) 1331 logger.warning(msg) 1332 finally: 1333 self._heartbeat_lock.release() 1334 1335 def _start_heartbeats(self) -> None: 1336 self._keep_alive_timer = _PeriodicTimer( 1337 self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self) 1338 ) 1339 1340 self._keep_alive_timer.set_name( 1341 f"RendezvousKeepAliveTimer_{self._this_node.local_id}" 1342 ) 1343 1344 self._keep_alive_timer.start() 1345 1346 def _stop_heartbeats(self) -> None: 1347 if self._keep_alive_timer is None: 1348 return 1349 1350 self._keep_alive_timer.cancel() 1351 1352 def _get_world(self) -> Tuple[int, int]: 1353 state = self._state_holder.state 1354 1355 return state.participants[self._this_node], len(state.participants) 1356 1357 def _wrap_store(self, store: Store) -> Store: 1358 key_prefix = ( 1359 f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}" 1360 ) 1361 1362 return dist.PrefixStore(key_prefix, store) 1363 1364 def _get_store(self) -> Store: 1365 return self._wrap_store(self._store) 1366 1367 def _get_deadline(self, timeout: timedelta) -> float: 1368 return time.monotonic() + timeout.total_seconds() 1369 1370 1371def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]: 1372 timeout = params.get_as_int(key + "_timeout") 1373 if timeout is None: 1374 return None 1375 return timedelta(seconds=timeout) 1376 1377 1378def create_handler( 1379 store: Store, backend: RendezvousBackend, params: RendezvousParameters 1380) -> DynamicRendezvousHandler: 1381 """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters. 1382 1383 Args: 1384 store: 1385 The C10d store to return as part of the rendezvous. 1386 backend: 1387 The backend to use to hold the rendezvous state. 1388 1389 +-------------------+------------------------------------------------------+ 1390 | Parameter | Description | 1391 +===================+======================================================+ 1392 | join_timeout | The total time, in seconds, within which the | 1393 | | rendezvous is expected to complete. Defaults to 600 | 1394 | | seconds. | 1395 +-------------------+------------------------------------------------------+ 1396 | last_call_timeout | An additional wait amount, in seconds, before | 1397 | | completing the rendezvous once the minimum number of | 1398 | | nodes has been reached. Defaults to 30 seconds. | 1399 +-------------------+------------------------------------------------------+ 1400 | close_timeout | The time, in seconds, within which the rendezvous is | 1401 | | expected to close after a call to | 1402 | | :py:meth:`RendezvousHandler.set_closed` or | 1403 | | :py:meth:`RendezvousHandler.shutdown`. Defaults to | 1404 | | 30 seconds. | 1405 +-------------------+------------------------------------------------------+ 1406 """ 1407 try: 1408 timeout = RendezvousTimeout( 1409 _get_timeout(params, "join"), 1410 _get_timeout(params, "last_call"), 1411 _get_timeout(params, "close"), 1412 ) 1413 1414 return DynamicRendezvousHandler.from_backend( 1415 params.run_id, 1416 store, 1417 backend, 1418 params.min_nodes, 1419 params.max_nodes, 1420 params.local_addr, 1421 timeout, 1422 ) 1423 except Exception as e: 1424 construct_and_record_rdzv_event( 1425 message=f"{type(e).__name__}: {str(e)}", 1426 run_id=params.run_id, 1427 node_state=NodeState.FAILED, 1428 ) 1429 raise 1430