xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/agent/server/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import abc
10import json
11import os
12import signal
13import socket
14import time
15import traceback
16import warnings
17from collections import defaultdict
18from contextlib import contextmanager
19from dataclasses import dataclass, field
20from enum import Enum
21from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
23import torch.distributed.elastic.rendezvous as rdzv
24import torch.distributed.elastic.utils.store as store_util
25from torch.distributed.elastic.events import Event, EventSource, record
26from torch.distributed.elastic.metrics import prof, put_metric
27from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException
28from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError
29from torch.distributed.elastic.utils.logging import get_logger
30
31
32__all__ = [
33    "WorkerSpec",
34    "Worker",
35    "WorkerState",
36    "WorkerGroup",
37    "RunResult",
38    "ElasticAgent",
39    "SimpleElasticAgent",
40]
41_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state"
42
43DEFAULT_ROLE = "default"
44logger = get_logger(__name__)
45
46
47@dataclass
48class WorkerSpec:
49    """Blueprint information about a particular type of worker.
50
51    For a given role, there must exist only a single worker spec.
52    The worker spec is expected to be homogeneous across all nodes (machines),
53    that is, each node runs the same number of workers for a particular spec.
54
55    Args:
56        role: user-defined role for the workers with this spec
57        local_world_size: number of local workers to run
58        fn: (deprecated, use ``entrypoint`` instead)
59        entrypoint: worker function or command
60        args: arguments to pass to ``entrypoint``
61        rdzv_handler: handles rdzv for this set of workers
62        max_restarts: number of max retries for the workers
63        monitor_interval: monitor status of workers every ``n`` seconds
64        master_port: fixed port to run the c10d store on rank 0;
65                     if not specified, a random free port is chosen
66        master_addr: fixed master_addr to run the c10d store on rank 0;
67                     if not specified, the hostname of the agent on rank 0 is used
68        redirects: redirect std streams to a file,
69                   selectively redirect for a particular
70                   local rank by passing a map
71        tee: tees the specified std stream(s) to console + file,
72             selectively tee for a particular local rank by passing a map,
73             takes precedence over ``redirects`` settings.
74
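    Example (an illustrative sketch; the trainer function, its argument, and the
    rendezvous handler below are placeholders, not part of this module)::

     def trainer(batch_size: int) -> str:
         return "do train"

     spec = WorkerSpec(
         role="trainer",
         local_world_size=4,
         entrypoint=trainer,
         args=(32,),
         rdzv_handler=my_rdzv_handler,  # a rdzv.RendezvousHandler from your rendezvous setup
         max_restarts=3,
         monitor_interval=0.1,
     )
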
75    """
76
77    role: str
78    local_world_size: int
79    rdzv_handler: rdzv.RendezvousHandler
80    fn: Optional[Callable] = None
81    # TODO @kiuk - make entrypoint a required field
82    entrypoint: Union[Callable, str, None] = None
83    args: Tuple = ()
84    max_restarts: int = 3
85    monitor_interval: float = 0.1
86    master_port: Optional[int] = None
87    master_addr: Optional[str] = None
88    local_addr: Optional[str] = None
89
90    def __post_init__(self):
91        assert self.local_world_size > 0
92        assert self.monitor_interval > 0
93
94        if self.fn:
95            warnings.warn(
96                "WorkerSpec.fn will be deprecated,"
97                " please use WorkerSpec.entrypoint instead",
98                category=DeprecationWarning,
99            )
100            self.entrypoint = self.fn
101        assert self.entrypoint
102
103    def get_entrypoint_name(self):
104        """Get the entry point name.
105
106        If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__``
107        else if the entrypoint is a binary (e.g. ``str``), returns the binary name.
108        """
109        if isinstance(self.entrypoint, str):
110            return os.path.basename(self.entrypoint)
111        else:
112            assert self.entrypoint is not None
113            return self.entrypoint.__qualname__
114
115
116class Worker:
117    """A worker instance.
118
119    Contrast this with ``WorkerSpec`` that represents the specifications of a
120    worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to
121    a ``WorkerSpec`` as an object is to a class.
122
123    The ``id`` of the worker is interpreted
124    by the specific implementation of ``ElasticAgent``. For a local
125    agent, it could be the ``pid (int)`` of the worker; for a remote
126    agent, it could be encoded as ``host:port (string)``.
127
128    Args:
129        id (Any): uniquely identifies a worker (interpreted by the agent)
130        local_rank (int): local rank of the worker
131        global_rank (int): global rank of the worker
132        role_rank (int): rank of the worker across all workers that have the same role
133        world_size (int): number of workers (globally)
134        role_world_size (int): number of workers that have the same role
135    """
136
137    __slots__ = [
138        "id",
139        "local_rank",
140        "global_rank",
141        "role_rank",
142        "world_size",
143        "role_world_size",
144    ]
145
146    def __init__(
147        self,
148        local_rank: int,
149        global_rank: int = -1,
150        role_rank: int = -1,
151        world_size: int = -1,
152        role_world_size: int = -1,
153    ):
154        # unique identifier for this worker
155        self.id: Any = None
156
157        # rank of the worker among workers with the same role being monitored
158        # by the same ``agent`` instance.
159        self.local_rank: int = local_rank
160
161        #  rank of the worker among all the workers across all roles
162        #  across all ``agent`` instances.
163        #  Global rank is not stable between re-rendezvous.
164        self.global_rank: int = global_rank
165
166        #  rank of the worker among all the workers with the same role
167        #  across all ``agent`` instances.
168        #  Role rank is not stable between re-rendezvous.
169        self.role_rank: int = role_rank
170
171        # total number of workers (globally). Due to elasticity
172        # the world size may change between re-rendezvous.
173        self.world_size: int = world_size
174
175        # total number of workers that share the same role. Due to elasticity
176        # the role world size may change between re-rendezvous.
177        self.role_world_size: int = role_world_size
178
179    def __str__(self):
180        return (
181            f"local_rank={self.local_rank},global_rank={self.global_rank}"
182            f",role_rank={self.role_rank},world_size={self.world_size}"
183            f",role_world_size={self.role_world_size}"
184        )
185
186    def __repr__(self):
187        return str(self)
188
189
190class WorkerState(str, Enum):
191    """A state of the ``WorkerGroup``.
192
193    Workers in a worker group change state as a unit. If a single worker
194    in a worker group fails the entire set is considered failed::
195
196      UNKNOWN - agent lost track of worker group state, unrecoverable
197      INIT - worker group object created, not yet started
198      HEALTHY - workers running and healthy
199      UNHEALTHY - workers running and unhealthy
200      STOPPED - workers stopped (interrupted) by the agent
201      SUCCEEDED - workers finished running (exit 0)
202      FAILED - workers failed to successfully finish (exit !0)
203
204
205    A worker group starts from an initial ``INIT`` state,
206    then progresses to ``HEALTHY`` or ``UNHEALTHY`` states,
207    and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state.
208
209    Worker groups can be interrupted and temporarily put into ``STOPPED`` state
210    by the agent. Workers in ``STOPPED`` state are scheduled to be restarted
211    in the near future by the agent. Some examples of workers being put into
212    ``STOPPED`` state are:
213
214    1. Worker group failure or unhealthy state observed
215    2. Membership change detected
216
217    When an action (start, stop, rdzv, retry, etc.) on a worker group fails
218    and leaves the action partially applied to the worker group,
219    the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled
220    exceptions during state-change events on the agent. The agent is not
221    expected to recover worker groups in the ``UNKNOWN`` state and is better off
222    self-terminating and allowing the job manager to retry the node.
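
    Example (illustrative; ``worker_group`` is assumed to be a ``WorkerGroup``
    tracked by an agent)::

     if worker_group.state == WorkerState.SUCCEEDED:
         pass  # terminal state: all workers exited with 0
     elif worker_group.state in (WorkerState.FAILED, WorkerState.UNKNOWN):
         pass  # terminal / unrecoverable
     elif WorkerState.is_running(worker_group.state):
         pass  # HEALTHY or UNHEALTHY: keep monitoring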
223    """
224
225    UNKNOWN = "UNKNOWN"
226    INIT = "INIT"
227    HEALTHY = "HEALTHY"
228    UNHEALTHY = "UNHEALTHY"
229    STOPPED = "STOPPED"
230    SUCCEEDED = "SUCCEEDED"
231    FAILED = "FAILED"
232
233    @staticmethod
234    def is_running(state: "WorkerState") -> bool:
235        """Return the state of the Worker.
236
237        Returns:
238             True if the worker state represents workers still running
239             (e.g. that the process exists but not necessarily healthy).
240        """
241        return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}
242
243
244class WorkerGroup:
245    """A set of ``Worker`` instances.
246
247    The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by an ``ElasticAgent``. Whether the worker
248    group contains cross-instance workers or not depends on the implementation of the agent.
249    """
250
251    __slots__ = [
252        "spec",
253        "workers",
254        "store",
255        "group_rank",
256        "group_world_size",
257        "state",
258        "master_addr",
259        "master_port",
260    ]
261
262    def __init__(self, spec: WorkerSpec):
263        self.spec = spec
264        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]
265
266        # assigned after rdzv
267        self.store = None
268        self.group_rank = None
269        self.group_world_size = None
270        self.master_addr = None
271        self.master_port = None
272
273        self.state = WorkerState.INIT
274
275
276class _RoleInstanceInfo:
277    """The class is used by the agent to exchange the information with other agents.
278
279    The information is used to determine the rank of the workers that agent
280    manages in heterogeneous environments, where different agents can have
281    different number of workers.
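
    Example (an illustrative sketch of how the rank 0 agent could use this class)::

     import functools

     infos = [
         _RoleInstanceInfo("trainer", rank=0, local_world_size=4),
         _RoleInstanceInfo("ps", rank=1, local_world_size=1),
         _RoleInstanceInfo("trainer", rank=2, local_world_size=4),
     ]
     # round-trips through the store as bytes
     infos = [_RoleInstanceInfo.deserialize(info.serialize()) for info in infos]
     # group agents by role (sorted by role name, then rank)
     infos.sort(key=functools.cmp_to_key(_RoleInstanceInfo.compare))
     start, end = _RoleInstanceInfo.find_role_boundaries(infos, "trainer")  # (1, 2)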
282    """
283
284    __slots__ = ["role", "rank", "local_world_size"]
285
286    def __init__(self, role: str, rank: int, local_world_size: int):
287        r"""Initialize the agent class instance.
288
289        Args:
290            role (str): user-defined role for the workers with this spec
291            rank (int): the rank of the agent
292            local_world_size (int): number of local workers to run
293        """
294        self.role = role
295        self.rank = rank
296        self.local_world_size = local_world_size
297
298    def serialize(self) -> bytes:
299        dict_data = {
300            "role": self.role,
301            "rank": self.rank,
302            "local_world_size": self.local_world_size,
303        }
304        return json.dumps(dict_data).encode(encoding="UTF-8")
305
306    @staticmethod
307    def deserialize(data: bytes):
308        dict_data = json.loads(data.decode(encoding="UTF-8"))
309        return _RoleInstanceInfo(
310            dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
311        )
312
313    @staticmethod
314    def compare(obj1, obj2) -> int:
315        if obj1.role == obj2.role:
316            return obj1.rank - obj2.rank
317        elif obj1.role > obj2.role:
318            return 1
319        else:
320            return -1
321
322    @staticmethod
323    def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
324        start_idx, end_idx = -1, -1
325        for idx, role_info in enumerate(roles_infos):
326            if role_info.role == role:
327                if start_idx == -1:
328                    start_idx = idx
329                end_idx = idx
330        return (start_idx, end_idx)
331
332
333@dataclass
334class RunResult:
335    """Return results of the worker executions.
336
337    Run results follow an "all-or-nothing" policy where the run is successful if and
338    only if ALL local workers managed by this agent complete successfully.
339
340    If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
341    field contains the outputs (return values) of the workers managed by THIS agent mapped
342    by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of
343    global rank 0.
344
345    .. note:: ``return_values`` are only meaningful when the worker entrypoint
346              is a function. Workers specified with a binary entrypoint do not canonically
347              have a return value, so the ``return_values`` field is meaningless and
348              may be empty.
349
350    If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
351    failure information, again, mapped by the GLOBAL rank of the worker that failed.
352
353    The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
354    a worker's final state can only be one of: succeeded, failed. Workers intentionally
355    terminated by the agent according to the agent's restart policy are not represented
356    in either ``return_values`` or ``failures``.
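
    Example (an illustrative sketch; ``agent`` is assumed to be an ``ElasticAgent``
    whose workers use a function entrypoint)::

     result = agent.run()
     if result.is_failed():
         # pick the earliest failure among this agent's workers
         first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
         raise RuntimeError(f"worker failed with exit code {first_failure.exit_code}")
     # only present on the agent that hosts global rank 0
     rank0_output = result.return_values.get(0)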
357    """
358
359    state: WorkerState
360    return_values: Dict[int, Any] = field(default_factory=dict)
361    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
362
363    def is_failed(self) -> bool:
364        return self.state == WorkerState.FAILED
365
366
367def _get_fq_hostname() -> str:
368    return socket.getfqdn(socket.gethostname())
369
370
371class ElasticAgent(abc.ABC):
372    """An agent process responsible for managing one or more worker processes.
373
374    The worker processes are assumed to be regular distributed PyTorch scripts.
375    When the worker process is created by the agent, the agent provides the
376    necessary information for the worker processes to properly initialize
377    a torch process group.
378
379    The exact deployment topology and agent-to-worker ratio are dependent
380    on the specific implementation of the agent and the user's job placement
381    preferences. For instance, to run a distributed training job on GPU with
382    8 trainers (one per GPU) one can:
383
384    1. Use 8 x single GPU instances, place an agent per instance, managing
385       1 worker per agent.
386    2. Use 4 x double GPU instances, place an agent per instance, managing
387       2 workers per agent.
388    3. Use 2 x quad GPU instances, place an agent per instance, managing
389       4 workers per agent.
390    4. Use 1 x 8 GPU instance, place an agent per instance, managing
391       8 workers per agent.
392
393    Usage
394    ::
395
396     group_result = agent.run()
397     if group_result.is_failed():
398         # workers failed
399         failure = group_result.failures[0]
400         logger.exception("worker 0 failed with exit code : %s", failure.exit_code)
401     else:
402         return group_result.return_values[0]  # return rank 0's results
403
404    """
405
406    @abc.abstractmethod
407    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
408        """Run the agent.
409
410        Supports retrying the worker group on failures up to ``max_restarts``.
411
412        Returns:
413            The result of the execution, containing the return values or
414            failure details for each worker mapped by the worker's global rank.
415
416        Raises:
417            Exception - any other failures NOT related to the worker processes
418        """
419        raise NotImplementedError
420
421    @abc.abstractmethod
422    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
423        """Return the ``WorkerGroup`` for the given ``role``.
424
425        Note that the worker group is a mutable object and hence in a
426        multi-threaded/process environment it may change state.
427        Implementors are encouraged (but not required) to return
428        a defensive read-only copy.
429        """
430        raise NotImplementedError
431
432
433class SimpleElasticAgent(ElasticAgent):
434    """An ``ElasticAgent`` that manages one particular type of worker role.
435
436    That is, it manages the workers (``WorkerGroup``) for a single ``WorkerSpec``,
437    such as one particular type of worker role.
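
    Subclasses are expected to implement ``_start_workers``, ``_stop_workers``,
    ``_monitor_workers`` and ``_shutdown``. A minimal sketch (illustrative only,
    not a working implementation)::

     class MyAgent(SimpleElasticAgent):
         def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
             # launch worker_group.spec.local_world_size processes and
             # return a map of local_rank -> worker id (e.g. pid)
             ...

         def _stop_workers(self, worker_group: WorkerGroup, is_restart: bool = False) -> None:
             ...

         def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
             # reflect the current state of the local workers
             return RunResult(state=WorkerState.HEALTHY)

         def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False) -> None:
             ...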
438    """
439
440    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
441        self._worker_group = WorkerGroup(spec)
442        self._remaining_restarts = self._worker_group.spec.max_restarts
443        self._store = None
444        self._exit_barrier_timeout = exit_barrier_timeout
445        self._total_execution_time = 0
446
447    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
448        return self._worker_group
449
450    @abc.abstractmethod
451    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
452        r"""Start ``worker_group.spec.local_world_size`` number of workers.
453
454        This is according to worker spec for the worker group .
455        Returns a map of ``local_rank`` to worker ``id``.
456        """
457        raise NotImplementedError
458
459    @abc.abstractmethod
460    def _stop_workers(
461        self, worker_group: WorkerGroup, is_restart: bool = False
462    ) -> None:
463        r"""Stop all workers in the given worker group.
464
465        Implementors must deal with workers in all states defined by
466        ``WorkerState``. That is, it must gracefully handle stopping
467        non-existent workers, unhealthy (stuck) workers, etc.
468        """
469        raise NotImplementedError
470
471    @abc.abstractmethod
472    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
473        r"""Check on the workers for the ``worker_group``.
474
475        This function also returns the new state of the worker group.
476        """
477        raise NotImplementedError
478
479    @abc.abstractmethod
480    def _shutdown(
481        self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
482    ) -> None:
483        """Clean up any resources that were allocated during the agent's work.
484
485        Args:
486            death_sig: Signal to send to the child processes; SIGTERM is the default
487        """
488        raise NotImplementedError
489
490    @prof
491    def _rendezvous(self, worker_group: WorkerGroup) -> None:
492        r"""Run rendezvous for the workers specified by the worker spec.
493
494        Assigns workers a new global rank and world size.
495        Updates the rendezvous store for the worker group.
496        """
497        spec = worker_group.spec
498
499        with self.record_duration("RENDEZVOUS"):
500            rdzv_info = spec.rdzv_handler.next_rendezvous()
501        store = rdzv_info.store
502        group_rank = rdzv_info.rank
503        group_world_size = rdzv_info.world_size
504
505        # master_addr/master_port could be explicitly overridden
506        # TODO: BC - specific to static rdzv and can be simplified further
507        master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr
508        master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port
509
510        self._store = store
511
512        with self.record_duration("ASSIGN_WORKER_RANKS"):
513            workers = self._assign_worker_ranks(
514                store, group_rank, group_world_size, spec
515            )
516        worker_group.workers = workers
517        worker_group.store = store
518        worker_group.group_rank = group_rank
519        worker_group.group_world_size = group_world_size
520        worker_group.master_addr = master_addr
521        worker_group.master_port = master_port
522
523        restart_count = spec.max_restarts - self._remaining_restarts
524
525        logger.info(
526            "[%(role)s] Rendezvous complete for workers. Result:\n"
527            "  restart_count=%(restart_count)s\n"
528            "  master_addr=%(master_addr)s\n"
529            "  master_port=%(master_port)s\n"
530            "  group_rank=%(group_rank)s\n"
531            "  group_world_size=%(group_world_size)s\n"
532            "  local_ranks=%(local_ranks)s\n"
533            "  role_ranks=%(role_ranks)s\n"
534            "  global_ranks=%(global_ranks)s\n"
535            "  role_world_sizes=%(role_world_sizes)s\n"
536            "  global_world_sizes=%(global_world_sizes)s\n",
537            {
538                "role": spec.role,
539                "restart_count": restart_count,
540                "master_addr": master_addr,
541                "master_port": master_port,
542                "group_rank": group_rank,
543                "group_world_size": group_world_size,
544                "local_ranks": [worker.local_rank for worker in workers],
545                "role_ranks": [worker.role_rank for worker in workers],
546                "global_ranks": [worker.global_rank for worker in workers],
547                "role_world_sizes": [worker.role_world_size for worker in workers],
548                "global_world_sizes": [worker.world_size for worker in workers],
549            },
550        )
551
552    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
553    #  `torch.distributed.elastic.metrics.prof`.
554    @prof
555    def _assign_worker_ranks(
556        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
557    ) -> List[Worker]:
558        """Determine proper ranks for worker processes.
559
560        The rank assignment is done according to the following algorithm:
561
562        1. Each agent writes its configuration (group_rank, group_world_size,
563           num_workers) to the common store.
564        2. The rank 0 agent reads all the role_info from the store and
565           determines each agent's worker ranks.
566        3. Determine the global rank: the global rank of a worker is computed
567           as the cumulative sum of the local_world_size of all workers in front of it.
568           For efficiency reasons each agent is assigned a base global rank
569           such that its workers are in the range [base_global_rank,
570           base_global_rank + local_world_size).
571        4. Determine the role rank: the role rank is determined using the algorithm
572           in point 3, with the exception that the ranks are calculated with
573           respect to the role name.
574        5. The rank 0 agent writes the assigned ranks to the store.
575        6. Each agent reads the assigned ranks from the store.
576
577        Time complexity: each worker O(1), rank0 O(n), overall O(n)
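
        Example (an illustrative walk-through of the rank computation)::

         # Two agents, both with role "trainer":
         #   agent0 (group_rank=0): local_world_size=4 -> base_global_rank=0,
         #       global ranks [0, 4), role ranks [0, 4)
         #   agent1 (group_rank=1): local_world_size=2 -> base_global_rank=4,
         #       global ranks [4, 6), role ranks [4, 6)
         # global_world_size = 6, role_world_size = 6 (single role)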
578        """
579
580        ROLE_INFO_PREFIX = "torchelastic/role_info/"
581        ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"
582
583        agent_role_info = _RoleInstanceInfo(
584            spec.role, group_rank, spec.local_world_size
585        )
586        store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())
587
588        # The TCP store is colocated with rank 0, so we can use it to do extra compute and reduce the overall number of store operations.
589        if group_rank == 0:
590            role_infos_bytes = store.multi_get(
591                [f"torchelastic/role_info/{i}" for i in range(group_world_size)]
592            )
593            role_infos = [
594                _RoleInstanceInfo.deserialize(info_bytes)
595                for info_bytes in role_infos_bytes
596            ]
597
598            role_sizes = defaultdict(lambda: 0)
599            global_size = 0
600            for role_info in role_infos:
601                role_sizes[role_info.role] += role_info.local_world_size
602                global_size += role_info.local_world_size
603
604            base_global_rank = 0
605            role_ranks = defaultdict(lambda: 0)
606
607            keys = []
608            values = []
609            for i, role_info in enumerate(role_infos):
610                keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
611                values.append(
612                    json.dumps(
613                        [
614                            base_global_rank,
615                            global_size,
616                            role_ranks[role_info.role],
617                            role_sizes[role_info.role],
618                        ]
619                    )
620                )
621
622                base_global_rank += role_info.local_world_size
623                role_ranks[role_info.role] += role_info.local_world_size
624
625            store.multi_set(keys, values)
626
627        # get will block until the data is available in the store.
628        (
629            base_global_rank,
630            global_world_size,
631            base_role_rank,
632            role_world_size,
633        ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}"))
634
635        workers = []
636        for local_rank in range(spec.local_world_size):
637            worker = Worker(
638                local_rank=local_rank,
639                global_rank=base_global_rank + local_rank,
640                role_rank=base_role_rank + local_rank,
641                world_size=global_world_size,
642                role_world_size=role_world_size,
643            )
644            workers.append(worker)
645        return workers
646
647    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
648    #  `torch.distributed.elastic.metrics.prof`.
649    @prof
650    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
651        r"""Start a fresh set of workers for the worker_group.
652
653        Essentially, a rendezvous followed by a ``start_workers``.
654        The caller should first call ``_stop_workers()`` to stop running workers
655        prior to calling this method.
656
657        Optimistically sets the state of the worker group that
658        just started as ``HEALTHY`` and delegates the actual monitoring
659        of state to the ``_monitor_workers()`` method.
660        """
661        role = worker_group.spec.role
662        logger.info("[%s] Rendezvous'ing worker group", role)
663
664        # TODO after stopping workers, wait at least monitor_interval*2 for
665        # workers on different nodes to fail on a collective op before waiting
666        # on the rdzv barrier, this way we ensure that nodes enter rdzv
667        # at around the same time and reduce false positive rdzv timeout errors
668        self._rendezvous(worker_group)
669
670        logger.info("[%s] Starting worker group", role)
671        worker_ids = self._start_workers(worker_group)
672        for local_rank, w_id in worker_ids.items():
673            worker = worker_group.workers[local_rank]
674            worker.id = w_id
675
676        worker_group.state = WorkerState.HEALTHY
677
678    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
679    #  `torch.distributed.elastic.metrics.prof`.
680    @prof
681    def _restart_workers(self, worker_group: WorkerGroup) -> None:
682        """Restart (stops, rendezvous, starts) all local workers in the group."""
683        role = worker_group.spec.role
684        logger.info("[%s] Stopping worker group", role)
685        self._stop_workers(worker_group, is_restart=True)
686        worker_group.state = WorkerState.STOPPED
687        self._initialize_workers(worker_group)
688
689    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
690    #  `torch.distributed.elastic.metrics.prof`.
691    @prof
692    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
693        start_time = time.monotonic()
694        shutdown_called: bool = False
695        try:
696            result = self._invoke_run(role)
697            self._total_execution_time = int(time.monotonic() - start_time)
698            self._record_metrics(result)
699            self._record_worker_events(result)
700            return result
701        except RendezvousGracefulExitError as e:
702            logger.info("Rendezvous gracefully exited: %s", e)
703        except SignalException as e:
704            logger.warning("Received %s death signal, shutting down workers", e.sigval)
705            self._shutdown(e.sigval)
706            shutdown_called = True
707            raise
708        finally:
709            if not shutdown_called:
710                self._shutdown()
711            # record the execution time in case there were any exceptions during run.
712            self._total_execution_time = int(time.monotonic() - start_time)
713
714    def get_event_failed(self) -> Event:
715        return self._construct_event(
716            state="FAILED",
717            source=EventSource.AGENT,
718            raw_error=traceback.format_exc(),
719        )
720
721    def get_event_succeeded(self) -> Event:
722        return self._construct_event(
723            state="SUCCEEDED",
724            source=EventSource.AGENT,
725        )
726
727    def _record_worker_events(self, result: RunResult) -> None:
728        for worker in self._worker_group.workers:
729            failure = result.failures.get(worker.global_rank)
730            state: str = self._get_worker_state(worker, result)
731            raw_error = json.dumps(failure.error_file_data) if failure else None
732            record(self._construct_event(state, EventSource.WORKER, worker, raw_error))
733
734    def _get_worker_state(self, worker: Worker, result: RunResult) -> str:
735        failure = result.failures.get(worker.global_rank)
736        if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure:
737            # The worker got terminated by the torchelastic agent via SIGTERM signal
738            return "TERMINATED"
739        elif failure or worker.global_rank in result.return_values:
740            return result.state.value
741        else:
742            raise ValueError(f"Unknown worker: {worker.global_rank}")
743
744    @contextmanager
745    def record_duration(self, state: str):
746        start_time = time.perf_counter()
747        try:
748            yield
749        finally:
750            end_time = time.perf_counter()
751            duration_ms = (end_time - start_time) * 1000
752            record(
753                self._construct_event(
754                    state=state, source=EventSource.AGENT, duration_ms=duration_ms
755                )
756            )
757
758    def _construct_event(
759        self,
760        state: str,
761        source: EventSource,
762        worker: Optional[Worker] = None,
763        raw_error: Optional[str] = None,
764        duration_ms: Optional[float] = None,
765    ) -> Event:
766        wg = self._worker_group
767        spec = wg.spec
768        md = {
769            "group_world_size": wg.group_world_size,
770            "entry_point": spec.get_entrypoint_name(),
771        }
772        if worker:
773            md["local_rank"] = (worker.local_rank,)
774            md["role_rank"] = (worker.role_rank,)
775            md["role_world_size"] = (worker.role_world_size,)
776            global_rank = worker.global_rank
777            worker_id = str(worker.id)
778        else:
779            global_rank = None
780            worker_id = None
781        md_str = json.dumps(md)
782        metadata = {
783            "run_id": spec.rdzv_handler.get_run_id(),
784            "global_rank": global_rank,
785            "group_rank": wg.group_rank,
786            "worker_id": worker_id,
787            "role": spec.role,
788            "hostname": _get_fq_hostname(),
789            "state": state,
790            "total_run_time": self._total_execution_time,
791            "rdzv_backend": spec.rdzv_handler.get_backend(),
792            "raw_error": raw_error,
793            "metadata": md_str,
794            "agent_restarts": spec.max_restarts - self._remaining_restarts,
795            "duration_ms": duration_ms,
796        }
797        return Event(
798            f"torchelastic.worker.status.{state}", source=source, metadata=metadata
799        )
800
801    def _record_metrics(self, group_results: RunResult):
802        is_failed = group_results.is_failed()
803        self._record_flakiness_metric(is_failed)
804        spec = self._worker_group.spec
805        restarts_happened = self._remaining_restarts != spec.max_restarts
806        put_metric(f"workers.{spec.role}.run_total", 1)
807        self._record_metric_with_condition(
808            "run_success_with_retries", not is_failed and restarts_happened
809        )
810        self._record_metric_with_condition(
811            "run_success_no_retries", not is_failed and not restarts_happened
812        )
813        self._record_metric_with_condition(
814            "run_failed_with_retries", is_failed and restarts_happened
815        )
816        self._record_metric_with_condition(
817            "run_failed_no_retries", is_failed and not restarts_happened
818        )
819
820    def _record_metric_with_condition(self, metric_name, condition):
821        spec = self._worker_group.spec
822        if condition:
823            put_metric(f"workers.{spec.role}.{metric_name}", 1)
824        else:
825            put_metric(f"workers.{spec.role}.{metric_name}", 0)
826
827    def _record_flakiness_metric(self, is_failed: bool = False):
828        if is_failed:
829            flakiness = 100.0
830        else:
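            # Illustrative arithmetic: with max_restarts=3 and one restart consumed
            # (remaining_restarts=2), flakiness = 100 - 100 * (2 + 1) / (3 + 1) = 25.0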
831            spec = self._worker_group.spec
832            flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / (
833                spec.max_restarts + 1
834            )
835        spec = self._worker_group.spec
836
837        put_metric(f"workers.{spec.role}.flakiness", int(flakiness))
838
839    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
840        # NOTE: currently only works for a single role
841
842        spec = self._worker_group.spec
843        role = spec.role
844
845        logger.info(
846            "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name()
847        )
848
849        self._initialize_workers(self._worker_group)
850        monitor_interval = spec.monitor_interval
851        rdzv_handler = spec.rdzv_handler
852
853        while True:
854            assert self._worker_group.state != WorkerState.INIT
855            time.sleep(monitor_interval)
856            run_result = self._monitor_workers(self._worker_group)
857            state = run_result.state
858            self._worker_group.state = state
859
860            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
861            put_metric(f"workers.{role}.{state.name.lower()}", 1)
862
863            if state == WorkerState.SUCCEEDED:
864                logger.info(
865                    "[%s] worker group successfully finished."
866                    " Waiting %s seconds for other agents to finish.",
867                    role,
868                    self._exit_barrier_timeout,
869                )
870                self._exit_barrier()
871                return run_result
872            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
873                if self._remaining_restarts > 0:
874                    logger.info(
875                        "[%s] Worker group %s. "
876                        "%s/%s attempts left;"
877                        " will restart worker group",
878                        role,
879                        state.name,
880                        self._remaining_restarts,
881                        spec.max_restarts,
882                    )
883                    self._remaining_restarts -= 1
884                    self._restart_workers(self._worker_group)
885                else:
886                    self._stop_workers(self._worker_group)
887                    self._worker_group.state = WorkerState.FAILED
888                    return run_result
889            elif state == WorkerState.HEALTHY:
890                # membership changes do not count as retries
891                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
892                group_rank = self._worker_group.group_rank
893                if num_nodes_waiting > 0:
894                    logger.info(
895                        "[%s] Detected %s "
896                        "new nodes from group_rank=%s; "
897                        "will restart worker group",
898                        role,
899                        num_nodes_waiting,
900                        group_rank,
901                    )
902                    self._restart_workers(self._worker_group)
903            else:
904                raise Exception(  # noqa: TRY002
905                    f"[{role}] Worker group in {state.name} state"
906                )
907
908    def _exit_barrier(self):
909        """
910        Define a barrier that keeps the agent process alive until all workers finish.
911
912        Wait for ``exit_barrier_timeout`` seconds for all agents to finish
913        executing their local workers (either successfully or not). This
914        acts as a safety guard against user scripts that terminate at different
915        times.
916        """
917        logger.info(
918            "Local worker group finished (%s). "
919            "Waiting %s seconds for other agents to finish",
920            self._worker_group.state,
921            self._exit_barrier_timeout,
922        )
923        start = time.time()
924        try:
925            store_util.barrier(
926                store=self._store,
927                world_size=self._worker_group.group_world_size,
928                key_prefix=_TERMINAL_STATE_SYNC_ID,
929                barrier_timeout=self._exit_barrier_timeout,
930            )
931            logger.info(
932                "Done waiting for other agents. Elapsed: %s seconds",
933                time.time() - start,
934            )
935        except SignalException as e:
936            logger.warning("Got termination signal: %s", e.sigval)
937            raise
938        except Exception:
939            logger.exception(
940                "Error waiting on exit barrier. Elapsed: %s seconds",
941                time.time() - start,
942            )
943