1# mypy: ignore-errors 2 3# Copyright (c) Facebook, Inc. and its affiliates. 4# All rights reserved. 5# 6# This source code is licensed under the BSD-style license found in the 7# LICENSE file in the root directory of this source tree. 8 9import abc 10import json 11import os 12import signal 13import socket 14import time 15import traceback 16import warnings 17from collections import defaultdict 18from contextlib import contextmanager 19from dataclasses import dataclass, field 20from enum import Enum 21from typing import Any, Callable, Dict, List, Optional, Tuple, Union 22 23import torch.distributed.elastic.rendezvous as rdzv 24import torch.distributed.elastic.utils.store as store_util 25from torch.distributed.elastic.events import Event, EventSource, record 26from torch.distributed.elastic.metrics import prof, put_metric 27from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException 28from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError 29from torch.distributed.elastic.utils.logging import get_logger 30 31 32__all__ = [ 33 "WorkerSpec", 34 "Worker", 35 "WorkerState", 36 "WorkerGroup", 37 "RunResult", 38 "ElasticAgent", 39 "SimpleElasticAgent", 40] 41_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state" 42 43DEFAULT_ROLE = "default" 44logger = get_logger(__name__) 45 46 47@dataclass 48class WorkerSpec: 49 """Blueprint information about a particular type of worker. 50 51 For a given role, there must only exist a single worker spec. 52 Worker spec is expected to be homogeneous across all nodes (machine), 53 that is each node runs the same number of workers for a particular spec. 54 55 Args: 56 role: user-defined role for the workers with this spec 57 local_world_size: number local workers to run 58 fn: (deprecated use entrypoint instead) 59 entrypoint: worker function or command 60 args: arguments to pass to ``entrypoint`` 61 rdzv_handler: handles rdzv for this set of workers 62 max_restarts: number of max retries for the workers 63 monitor_interval: monitor status of workers every ``n`` seconds 64 master_port: fixed port to run the c10d store on rank 0 65 if not specified then will chose a random free port 66 master_addr: fixed master_addr to run the c10d store on rank 0 67 if not specified then will chose hostname on agent rank 0 68 redirects: redirect std streams to a file, 69 selectively redirect for a particular 70 local rank by passing a map 71 tee: tees the specified std stream(s) to console + file, 72 selectively tee for a particular local rank by passing a map, 73 takes precedence over ``redirects`` settings. 74 75 """ 76 77 role: str 78 local_world_size: int 79 rdzv_handler: rdzv.RendezvousHandler 80 fn: Optional[Callable] = None 81 # TODO @kiuk - make entrypoint a required field 82 entrypoint: Union[Callable, str, None] = None 83 args: Tuple = () 84 max_restarts: int = 3 85 monitor_interval: float = 0.1 86 master_port: Optional[int] = None 87 master_addr: Optional[str] = None 88 local_addr: Optional[str] = None 89 90 def __post_init__(self): 91 assert self.local_world_size > 0 92 assert self.monitor_interval > 0 93 94 if self.fn: 95 warnings.warn( 96 "WorkerSpec.fn will be deprecated," 97 " please use WorkerSpec.entrypoint instead", 98 category=DeprecationWarning, 99 ) 100 self.entrypoint = self.fn 101 assert self.entrypoint 102 103 def get_entrypoint_name(self): 104 """Get the entry point name. 105 106 If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__`` 107 else if the entrypoint is a binary (e.g. ``str``), returns the binary name. 108 """ 109 if isinstance(self.entrypoint, str): 110 return os.path.basename(self.entrypoint) 111 else: 112 assert self.entrypoint is not None 113 return self.entrypoint.__qualname__ 114 115 116class Worker: 117 """A worker instance. 118 119 Contrast this with ``WorkerSpec`` that represents the specifications of a 120 worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to 121 a ``WorkerSpec`` as an object is to a class. 122 123 The ``id`` of the worker is interpreted 124 by the specific implementation of ``ElasticAgent``. For a local 125 agent, it could be the ``pid (int)`` of the worker, for a remote 126 agent it could be encoded as ``host:port (string)``. 127 128 Args: 129 id (Any): uniquely identifies a worker (interpreted by the agent) 130 local_rank (int): local rank of the worker 131 global_rank (int): global rank of the worker 132 role_rank (int): rank of the worker across all workers that have the same role 133 world_size (int): number of workers (globally) 134 role_world_size (int): number of workers that have the same role 135 """ 136 137 __slots__ = [ 138 "id", 139 "local_rank", 140 "global_rank", 141 "role_rank", 142 "world_size", 143 "role_world_size", 144 ] 145 146 def __init__( 147 self, 148 local_rank: int, 149 global_rank: int = -1, 150 role_rank: int = -1, 151 world_size: int = -1, 152 role_world_size: int = -1, 153 ): 154 # unique identifier for this worker 155 self.id: Any = None 156 157 # rank of the worker among workers with the same role being monitored 158 # by the same ``agent`` instance. 159 self.local_rank: int = local_rank 160 161 # rank of the worker among all the workers across all roles 162 # across all ``agent`` instances. 163 # Global rank is not stable between re-rendezvous. 164 self.global_rank: int = global_rank 165 166 # rank of the worker among all the workers with the same role 167 # across all ``agent`` instances. 168 # Role rank is not stable between re-rendezvous. 169 self.role_rank: int = role_rank 170 171 # total number of workers (globally). Due to elasticity 172 # the world size may change between re-rendezvous. 173 self.world_size: int = world_size 174 175 # total number of workers that share the same role. Due to elasticity 176 # the role world size may change between re-rendezvous. 177 self.role_world_size: int = role_world_size 178 179 def __str__(self): 180 return ( 181 f"local_rank={self.local_rank},global_rank={self.global_rank}" 182 f",role_rank={self.role_rank},world_size={self.world_size}" 183 f",role_world_size={self.role_world_size}" 184 ) 185 186 def __repr__(self): 187 return str(self) 188 189 190class WorkerState(str, Enum): 191 """A state of the ``WorkerGroup``. 192 193 Workers in a worker group change state as a unit. If a single worker 194 in a worker group fails the entire set is considered failed:: 195 196 UNKNOWN - agent lost track of worker group state, unrecoverable 197 INIT - worker group object created not yet started 198 HEALTHY - workers running and healthy 199 UNHEALTHY - workers running and unhealthy 200 STOPPED - workers stopped (interrupted) by the agent 201 SUCCEEDED - workers finished running (exit 0) 202 FAILED - workers failed to successfully finish (exit !0) 203 204 205 A worker group starts from an initial ``INIT`` state, 206 then progresses to ``HEALTHY`` or ``UNHEALTHY`` states, 207 and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state. 208 209 Worker groups can be interrupted and temporarily put into ``STOPPED`` state 210 by the agent. Workers in ``STOPPED`` state are scheduled to be restarted 211 in the near future by the agent. Some examples of workers being put into 212 ``STOPPED`` state are: 213 214 1. Worker group failure|unhealthy observed 215 2. Membership change detected 216 217 When actions (start, stop, rdzv, retry, etc) on worker group fails 218 and results in the action being partially applied to the worker group 219 the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled 220 exceptions during state change events on the agent. The agent is not 221 expected to recover worker groups in ``UNKNOWN`` state and is better off 222 self terminating and allowing the job manager to retry the node. 223 """ 224 225 UNKNOWN = "UNKNOWN" 226 INIT = "INIT" 227 HEALTHY = "HEALTHY" 228 UNHEALTHY = "UNHEALTHY" 229 STOPPED = "STOPPED" 230 SUCCEEDED = "SUCCEEDED" 231 FAILED = "FAILED" 232 233 @staticmethod 234 def is_running(state: "WorkerState") -> bool: 235 """Return the state of the Worker. 236 237 Returns: 238 True if the worker state represents workers still running 239 (e.g. that the process exists but not necessarily healthy). 240 """ 241 return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY} 242 243 244class WorkerGroup: 245 """A set of ``Worker`` instances. 246 247 The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker 248 group contains cross instance workers or not depends on the implementation of the agent. 249 """ 250 251 __slots__ = [ 252 "spec", 253 "workers", 254 "store", 255 "group_rank", 256 "group_world_size", 257 "state", 258 "master_addr", 259 "master_port", 260 ] 261 262 def __init__(self, spec: WorkerSpec): 263 self.spec = spec 264 self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)] 265 266 # assigned after rdzv 267 self.store = None 268 self.group_rank = None 269 self.group_world_size = None 270 self.master_addr = None 271 self.master_port = None 272 273 self.state = WorkerState.INIT 274 275 276class _RoleInstanceInfo: 277 """The class is used by the agent to exchange the information with other agents. 278 279 The information is used to determine the rank of the workers that agent 280 manages in heterogeneous environments, where different agents can have 281 different number of workers. 282 """ 283 284 __slots__ = ["role", "rank", "local_world_size"] 285 286 def __init__(self, role: str, rank: int, local_world_size: int): 287 r"""Initialize the agent class instance. 288 289 Args: 290 role (str): user-defined role for the workers with this spec 291 rank (int): the rank of the agent 292 local_world_size (int): number of local workers to run 293 """ 294 self.role = role 295 self.rank = rank 296 self.local_world_size = local_world_size 297 298 def serialize(self) -> bytes: 299 dict_data = { 300 "role": self.role, 301 "rank": self.rank, 302 "local_world_size": self.local_world_size, 303 } 304 return json.dumps(dict_data).encode(encoding="UTF-8") 305 306 @staticmethod 307 def deserialize(data: bytes): 308 dict_data = json.loads(data.decode(encoding="UTF-8")) 309 return _RoleInstanceInfo( 310 dict_data["role"], dict_data["rank"], dict_data["local_world_size"] 311 ) 312 313 @staticmethod 314 def compare(obj1, obj2) -> int: 315 if obj1.role == obj2.role: 316 return obj1.rank - obj2.rank 317 elif obj1.role > obj2.role: 318 return 1 319 else: 320 return -1 321 322 @staticmethod 323 def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]: 324 start_idx, end_idx = -1, -1 325 for idx, role_info in enumerate(roles_infos): 326 if role_info.role == role: 327 if start_idx == -1: 328 start_idx = idx 329 end_idx = idx 330 return (start_idx, end_idx) 331 332 333@dataclass 334class RunResult: 335 """Return results of the worker executions. 336 337 Run results follow an "all-or-nothing" policy where the run is successful if and 338 only if ALL local workers managed by this agent complete successfully. 339 340 If the result is successful (e.g. ``is_failed() = False``) then the ``return_values`` 341 field contains the outputs (return values) of the workers managed by THIS agent mapped 342 by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of 343 global rank 0. 344 345 .. note:: ``return_values`` are only meaningful for when the worker entrypoint 346 is a function. Workers specified as a binary entrypoint do not canonically 347 have a return value and the ``return_values`` field is meaningless and 348 may be empty. 349 350 If ``is_failed()`` returns ``True`` then the ``failures`` field contains the 351 failure information, again, mapped by the GLOBAL rank of the worker that failed. 352 353 The keys in ``return_values`` and ``failures`` are mutually exclusive, that is, 354 a worker's final state can only be one of: succeeded, failed. Workers intentionally 355 terminated by the agent according to the agent's restart policy, are not represented 356 in either ``return_values`` nor ``failures``. 357 """ 358 359 state: WorkerState 360 return_values: Dict[int, Any] = field(default_factory=dict) 361 failures: Dict[int, ProcessFailure] = field(default_factory=dict) 362 363 def is_failed(self) -> bool: 364 return self.state == WorkerState.FAILED 365 366 367def _get_fq_hostname() -> str: 368 return socket.getfqdn(socket.gethostname()) 369 370 371class ElasticAgent(abc.ABC): 372 """An agent process responsible for managing one or more worker processes. 373 374 The worker processes are assumed to be regular distributed PyTorch scripts. 375 When the worker process is created by the agent, the agent provides the 376 necessary information for the worker processes to properly initialize 377 a torch process group. 378 379 The exact deployment topology and ratio of agent-to-worker is dependent 380 on the specific implementation of the agent and the user's job placement 381 preferences. For instance, to run a distributed training job on GPU with 382 8 trainers (one per GPU) one can: 383 384 1. Use 8 x single GPU instances, place an agent per instance, managing 385 1 worker per agent. 386 2. Use 4 x double GPU instances, place an agent per instance, managing 387 2 workers per agent. 388 3. Use 2 x quad GPU instances, place an agent per instance, managing 389 4 workers per agent. 390 4. Use 1 x 8 GPU instance, place an agent per instance, managing 391 8 workers per agent. 392 393 Usage 394 :: 395 396 group_result = agent.run() 397 if group_result.is_failed(): 398 # workers failed 399 failure = group_result.failures[0] 400 logger.exception("worker 0 failed with exit code : %s", failure.exit_code) 401 else: 402 return group_result.return_values[0] # return rank 0's results 403 404 """ 405 406 @abc.abstractmethod 407 def run(self, role: str = DEFAULT_ROLE) -> RunResult: 408 """Run the agent. 409 410 Supports retrying the worker group on failures up to ``max_restarts``. 411 412 Returns: 413 The result of the execution, containing the return values or 414 failure details for each worker mapped by the worker's global rank. 415 416 Raises: 417 Exception - any other failures NOT related to worker process 418 """ 419 raise NotImplementedError 420 421 @abc.abstractmethod 422 def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: 423 """Return the ``WorkerGroup`` for the given ``role``. 424 425 Note that the worker group is a mutable object and hence in a 426 multi-threaded/process environment it may change state. 427 Implementors are encouraged (but not required) to return 428 a defensive read-only copy. 429 """ 430 raise NotImplementedError 431 432 433class SimpleElasticAgent(ElasticAgent): 434 """An ``ElasticAgent`` that manages one particular type of worker role. 435 436 An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec`` 437 such as one particular type of worker role. 438 """ 439 440 def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300): 441 self._worker_group = WorkerGroup(spec) 442 self._remaining_restarts = self._worker_group.spec.max_restarts 443 self._store = None 444 self._exit_barrier_timeout = exit_barrier_timeout 445 self._total_execution_time = 0 446 447 def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: 448 return self._worker_group 449 450 @abc.abstractmethod 451 def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]: 452 r"""Start ``worker_group.spec.local_world_size`` number of workers. 453 454 This is according to worker spec for the worker group . 455 Returns a map of ``local_rank`` to worker ``id``. 456 """ 457 raise NotImplementedError 458 459 @abc.abstractmethod 460 def _stop_workers( 461 self, worker_group: WorkerGroup, is_restart: bool = False 462 ) -> None: 463 r"""Stop all workers in the given worker group. 464 465 Implementors must deal with workers in all states defined by 466 ``WorkerState``. That is, it must gracefully handle stopping 467 non-existent workers, unhealthy (stuck) workers, etc. 468 """ 469 raise NotImplementedError 470 471 @abc.abstractmethod 472 def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult: 473 r"""Check on the workers for the ``worker_group``. 474 475 This function also returns the new state of the worker group. 476 """ 477 raise NotImplementedError 478 479 @abc.abstractmethod 480 def _shutdown( 481 self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False 482 ) -> None: 483 """Clean up any resources that were allocated during the agent's work. 484 485 Args: 486 death_sig: Signal to send to the child process, SIGTERM is default 487 """ 488 raise NotImplementedError 489 490 @prof 491 def _rendezvous(self, worker_group: WorkerGroup) -> None: 492 r"""Run rendezvous for the workers specified by the worker spec. 493 494 Assigns workers a new global rank and world size. 495 Updates the rendezvous store for the worker group. 496 """ 497 spec = worker_group.spec 498 499 with self.record_duration("RENDEZVOUS"): 500 rdzv_info = spec.rdzv_handler.next_rendezvous() 501 store = rdzv_info.store 502 group_rank = rdzv_info.rank 503 group_world_size = rdzv_info.world_size 504 505 # master_addr/master_port could be explicitly overriden 506 # TODO: BC - specific to static rdzv and can be simplifed further 507 master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr 508 master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port 509 510 self._store = store 511 512 with self.record_duration("ASSIGN_WORKER_RANKS"): 513 workers = self._assign_worker_ranks( 514 store, group_rank, group_world_size, spec 515 ) 516 worker_group.workers = workers 517 worker_group.store = store 518 worker_group.group_rank = group_rank 519 worker_group.group_world_size = group_world_size 520 worker_group.master_addr = master_addr 521 worker_group.master_port = master_port 522 523 restart_count = spec.max_restarts - self._remaining_restarts 524 525 logger.info( 526 "[%(role)s] Rendezvous complete for workers. Result:\n" 527 " restart_count=%(restart_count)s\n" 528 " master_addr=%(master_addr)s\n" 529 " master_port=%(master_port)s\n" 530 " group_rank=%(group_rank)s\n" 531 " group_world_size=%(group_world_size)s\n" 532 " local_ranks=%(local_ranks)s\n" 533 " role_ranks=%(role_ranks)s\n" 534 " global_ranks=%(global_ranks)s\n" 535 " role_world_sizes=%(role_world_sizes)s\n" 536 " global_world_sizes=%(global_world_sizes)s\n", 537 { 538 "role": spec.role, 539 "restart_count": restart_count, 540 "master_addr": master_addr, 541 "master_port": master_port, 542 "group_rank": group_rank, 543 "group_world_size": group_world_size, 544 "local_ranks": [worker.local_rank for worker in workers], 545 "role_ranks": [worker.role_rank for worker in workers], 546 "global_ranks": [worker.global_rank for worker in workers], 547 "role_world_sizes": [worker.role_world_size for worker in workers], 548 "global_world_sizes": [worker.world_size for worker in workers], 549 }, 550 ) 551 552 # pyre-fixme[56]: Pyre was not able to infer the type of the decorator 553 # `torch.distributed.elastic.metrics.prof`. 554 @prof 555 def _assign_worker_ranks( 556 self, store, group_rank: int, group_world_size: int, spec: WorkerSpec 557 ) -> List[Worker]: 558 """Determine proper ranks for worker processes. 559 560 The rank assignment is done according to the following algorithm: 561 562 1. Each agent writes its configuration(group_rank, group_world_size 563 , num_workers) to the common store. 564 2. The rank 0 agent reads all the role_info from the store and 565 determines each agents worker ranks. 566 3. Determine the global rank: the global rank of the workers is computed 567 by cumulative sum of the local_world_size for all workers in front of it. 568 For efficiency reasons each worker is assigned a base global rank 569 such that it's workers are in the range [base_global_rank, 570 base_global_rank + local_world_size). 571 4. Determine the role rank: The role rank is determined using the algorithms 572 in the point 3 with the exception that the ranks are calculated with 573 respect to the role name. 574 5. The rank 0 agent writes the assigned ranks to the store. 575 6. Each agent reads the assigned ranks from the store. 576 577 Time complexity: each worker O(1), rank0 O(n), overall O(n) 578 """ 579 580 ROLE_INFO_PREFIX = "torchelastic/role_info/" 581 ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/" 582 583 agent_role_info = _RoleInstanceInfo( 584 spec.role, group_rank, spec.local_world_size 585 ) 586 store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize()) 587 588 # tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations. 589 if group_rank == 0: 590 role_infos_bytes = store.multi_get( 591 [f"torchelastic/role_info/{i}" for i in range(group_world_size)] 592 ) 593 role_infos = [ 594 _RoleInstanceInfo.deserialize(info_bytes) 595 for info_bytes in role_infos_bytes 596 ] 597 598 role_sizes = defaultdict(lambda: 0) 599 global_size = 0 600 for role_info in role_infos: 601 role_sizes[role_info.role] += role_info.local_world_size 602 global_size += role_info.local_world_size 603 604 base_global_rank = 0 605 role_ranks = defaultdict(lambda: 0) 606 607 keys = [] 608 values = [] 609 for i, role_info in enumerate(role_infos): 610 keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}") 611 values.append( 612 json.dumps( 613 [ 614 base_global_rank, 615 global_size, 616 role_ranks[role_info.role], 617 role_sizes[role_info.role], 618 ] 619 ) 620 ) 621 622 base_global_rank += role_info.local_world_size 623 role_ranks[role_info.role] += role_info.local_world_size 624 625 store.multi_set(keys, values) 626 627 # get will block until the data is available in the store. 628 ( 629 base_global_rank, 630 global_world_size, 631 base_role_rank, 632 role_world_size, 633 ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}")) 634 635 workers = [] 636 for local_rank in range(spec.local_world_size): 637 worker = Worker( 638 local_rank=local_rank, 639 global_rank=base_global_rank + local_rank, 640 role_rank=base_role_rank + local_rank, 641 world_size=global_world_size, 642 role_world_size=role_world_size, 643 ) 644 workers.append(worker) 645 return workers 646 647 # pyre-fixme[56]: Pyre was not able to infer the type of the decorator 648 # `torch.distributed.elastic.metrics.prof`. 649 @prof 650 def _initialize_workers(self, worker_group: WorkerGroup) -> None: 651 r"""Start a fresh set of workers for the worker_group. 652 653 Essentially, a rendezvous followed by a ``start_workers``. 654 The caller should first call ``_stop_workers()`` to stop running workers 655 prior to calling this method. 656 657 Optimistically sets the state of the worker group that 658 just started as ``HEALTHY`` and delegates the actual monitoring 659 of state to ``_monitor_workers()`` method 660 """ 661 role = worker_group.spec.role 662 logger.info("[%s] Rendezvous'ing worker group", role) 663 664 # TODO after stopping workers, wait at least monitor_interval*2 for 665 # workers on different nodes to fail on a collective op before waiting 666 # on the rdzv barrier, this way we ensure that nodes enter rdzv 667 # at around the same time and reduce false positive rdzv timeout errors 668 self._rendezvous(worker_group) 669 670 logger.info("[%s] Starting worker group", role) 671 worker_ids = self._start_workers(worker_group) 672 for local_rank, w_id in worker_ids.items(): 673 worker = worker_group.workers[local_rank] 674 worker.id = w_id 675 676 worker_group.state = WorkerState.HEALTHY 677 678 # pyre-fixme[56]: Pyre was not able to infer the type of the decorator 679 # `torch.distributed.elastic.metrics.prof`. 680 @prof 681 def _restart_workers(self, worker_group: WorkerGroup) -> None: 682 """Restart (stops, rendezvous, starts) all local workers in the group.""" 683 role = worker_group.spec.role 684 logger.info("[%s] Stopping worker group", role) 685 self._stop_workers(worker_group, is_restart=True) 686 worker_group.state = WorkerState.STOPPED 687 self._initialize_workers(worker_group) 688 689 # pyre-fixme[56]: Pyre was not able to infer the type of the decorator 690 # `torch.distributed.elastic.metrics.prof`. 691 @prof 692 def run(self, role: str = DEFAULT_ROLE) -> RunResult: 693 start_time = time.monotonic() 694 shutdown_called: bool = False 695 try: 696 result = self._invoke_run(role) 697 self._total_execution_time = int(time.monotonic() - start_time) 698 self._record_metrics(result) 699 self._record_worker_events(result) 700 return result 701 except RendezvousGracefulExitError as e: 702 logger.info("Rendezvous gracefully exited: %s", e) 703 except SignalException as e: 704 logger.warning("Received %s death signal, shutting down workers", e.sigval) 705 self._shutdown(e.sigval) 706 shutdown_called = True 707 raise 708 finally: 709 if not shutdown_called: 710 self._shutdown() 711 # record the execution time in case there were any exceptions during run. 712 self._total_execution_time = int(time.monotonic() - start_time) 713 714 def get_event_failed(self) -> Event: 715 return self._construct_event( 716 state="FAILED", 717 source=EventSource.AGENT, 718 raw_error=traceback.format_exc(), 719 ) 720 721 def get_event_succeeded(self) -> Event: 722 return self._construct_event( 723 state="SUCCEEDED", 724 source=EventSource.AGENT, 725 ) 726 727 def _record_worker_events(self, result: RunResult) -> None: 728 for worker in self._worker_group.workers: 729 failure = result.failures.get(worker.global_rank) 730 state: str = self._get_worker_state(worker, result) 731 raw_error = json.dumps(failure.error_file_data) if failure else None 732 record(self._construct_event(state, EventSource.WORKER, worker, raw_error)) 733 734 def _get_worker_state(self, worker: Worker, result: RunResult) -> str: 735 failure = result.failures.get(worker.global_rank) 736 if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure: 737 # The worker got terminated by the torchelastic agent via SIGTERM signal 738 return "TERMINATED" 739 elif failure or worker.global_rank in result.return_values: 740 return result.state.value 741 else: 742 raise ValueError(f"Unknown worker: {worker.global_rank}") 743 744 @contextmanager 745 def record_duration(self, state: str): 746 start_time = time.perf_counter() 747 try: 748 yield 749 finally: 750 end_time = time.perf_counter() 751 duration_ms = (end_time - start_time) * 1000 752 record( 753 self._construct_event( 754 state=state, source=EventSource.AGENT, duration_ms=duration_ms 755 ) 756 ) 757 758 def _construct_event( 759 self, 760 state: str, 761 source: EventSource, 762 worker: Optional[Worker] = None, 763 raw_error: Optional[str] = None, 764 duration_ms: Optional[float] = None, 765 ) -> Event: 766 wg = self._worker_group 767 spec = wg.spec 768 md = { 769 "group_world_size": wg.group_world_size, 770 "entry_point": spec.get_entrypoint_name(), 771 } 772 if worker: 773 md["local_rank"] = (worker.local_rank,) 774 md["role_rank"] = (worker.role_rank,) 775 md["role_world_size"] = (worker.role_world_size,) 776 global_rank = worker.global_rank 777 worker_id = str(worker.id) 778 else: 779 global_rank = None 780 worker_id = None 781 md_str = json.dumps(md) 782 metadata = { 783 "run_id": spec.rdzv_handler.get_run_id(), 784 "global_rank": global_rank, 785 "group_rank": wg.group_rank, 786 "worker_id": worker_id, 787 "role": spec.role, 788 "hostname": _get_fq_hostname(), 789 "state": state, 790 "total_run_time": self._total_execution_time, 791 "rdzv_backend": spec.rdzv_handler.get_backend(), 792 "raw_error": raw_error, 793 "metadata": md_str, 794 "agent_restarts": spec.max_restarts - self._remaining_restarts, 795 "duration_ms": duration_ms, 796 } 797 return Event( 798 f"torchelastic.worker.status.{state}", source=source, metadata=metadata 799 ) 800 801 def _record_metrics(self, group_results: RunResult): 802 is_failed = group_results.is_failed() 803 self._record_flakiness_metric(is_failed) 804 spec = self._worker_group.spec 805 restarts_happened = self._remaining_restarts != spec.max_restarts 806 put_metric(f"workers.{spec.role}.run_total", 1) 807 self._record_metric_with_condition( 808 "run_success_with_retries", not is_failed and restarts_happened 809 ) 810 self._record_metric_with_condition( 811 "run_success_no_retries", not is_failed and not restarts_happened 812 ) 813 self._record_metric_with_condition( 814 "run_failed_with_retries", is_failed and restarts_happened 815 ) 816 self._record_metric_with_condition( 817 "run_failed_no_retries", is_failed and not restarts_happened 818 ) 819 820 def _record_metric_with_condition(self, metric_name, condition): 821 spec = self._worker_group.spec 822 if condition: 823 put_metric(f"workers.{spec.role}.{metric_name}", 1) 824 else: 825 put_metric(f"workers.{spec.role}.{metric_name}", 0) 826 827 def _record_flakiness_metric(self, is_failed: bool = False): 828 if is_failed: 829 flakiness = 100.0 830 else: 831 spec = self._worker_group.spec 832 flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / ( 833 spec.max_restarts + 1 834 ) 835 spec = self._worker_group.spec 836 837 put_metric(f"workers.{spec.role}.flakiness", int(flakiness)) 838 839 def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: 840 # NOTE: currently only works for a single role 841 842 spec = self._worker_group.spec 843 role = spec.role 844 845 logger.info( 846 "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name() 847 ) 848 849 self._initialize_workers(self._worker_group) 850 monitor_interval = spec.monitor_interval 851 rdzv_handler = spec.rdzv_handler 852 853 while True: 854 assert self._worker_group.state != WorkerState.INIT 855 time.sleep(monitor_interval) 856 run_result = self._monitor_workers(self._worker_group) 857 state = run_result.state 858 self._worker_group.state = state 859 860 put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts) 861 put_metric(f"workers.{role}.{state.name.lower()}", 1) 862 863 if state == WorkerState.SUCCEEDED: 864 logger.info( 865 "[%s] worker group successfully finished." 866 " Waiting %s seconds for other agents to finish.", 867 role, 868 self._exit_barrier_timeout, 869 ) 870 self._exit_barrier() 871 return run_result 872 elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}: 873 if self._remaining_restarts > 0: 874 logger.info( 875 "[%s] Worker group %s. " 876 "%s/%s attempts left;" 877 " will restart worker group", 878 role, 879 state.name, 880 self._remaining_restarts, 881 spec.max_restarts, 882 ) 883 self._remaining_restarts -= 1 884 self._restart_workers(self._worker_group) 885 else: 886 self._stop_workers(self._worker_group) 887 self._worker_group.state = WorkerState.FAILED 888 return run_result 889 elif state == WorkerState.HEALTHY: 890 # membership changes do not count as retries 891 num_nodes_waiting = rdzv_handler.num_nodes_waiting() 892 group_rank = self._worker_group.group_rank 893 if num_nodes_waiting > 0: 894 logger.info( 895 "[%s] Detected %s " 896 "new nodes from group_rank=%s; " 897 "will restart worker group", 898 role, 899 num_nodes_waiting, 900 group_rank, 901 ) 902 self._restart_workers(self._worker_group) 903 else: 904 raise Exception( # noqa: TRY002 905 f"[{role}] Worker group in {state.name} state" 906 ) 907 908 def _exit_barrier(self): 909 """ 910 Define a barrier that keeps the agent process alive until all workers finish. 911 912 Wait for ``exit_barrier_timeout`` seconds for all agents to finish 913 executing their local workers (either successfully or not). This 914 acts as a safety guard against user scripts that terminate at different 915 times. 916 """ 917 logger.info( 918 "Local worker group finished (%s). " 919 "Waiting %s seconds for other agents to finish", 920 self._worker_group.state, 921 self._exit_barrier_timeout, 922 ) 923 start = time.time() 924 try: 925 store_util.barrier( 926 store=self._store, 927 world_size=self._worker_group.group_world_size, 928 key_prefix=_TERMINAL_STATE_SYNC_ID, 929 barrier_timeout=self._exit_barrier_timeout, 930 ) 931 logger.info( 932 "Done waiting for other agents. Elapsed: %s seconds", 933 time.time() - start, 934 ) 935 except SignalException as e: 936 logger.warning("Got termination signal: %s", e.sigval) 937 raise 938 except Exception: 939 logger.exception( 940 "Error waiting on exit barrier. Elapsed: %s seconds", 941 time.time() - start, 942 ) 943