xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/agent/server/local_elastic_agent.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10
11import json
12import os
13import signal
14import socket
15import time
16import uuid
17from string import Template
18from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
19
20import torch.distributed.elastic.timer as timer
21from torch.distributed.elastic import events
22from torch.distributed.elastic.agent.server.api import (
23    RunResult,
24    SimpleElasticAgent,
25    WorkerGroup,
26    WorkerSpec,
27    WorkerState,
28)
29from torch.distributed.elastic.agent.server.health_check_server import (
30    create_healthcheck_server,
31    HealthCheckServer,
32)
33from torch.distributed.elastic.metrics.api import prof
34from torch.distributed.elastic.multiprocessing import (
35    LogsSpecs,
36    PContext,
37    start_processes,
38)
39from torch.distributed.elastic.utils import macros
40from torch.distributed.elastic.utils.logging import get_logger
41
42
43if TYPE_CHECKING:
44    from torch.distributed.elastic.events.api import EventMetadataValue
45
# Names of the environment variables this module reads.
# When set to "1" in the agent's environment, the file-timer watchdog is started.
TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
# Port on which the optional health check server listens.
TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT"
# File path of the watchdog's named pipe; also exported to every worker.
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"

# Public API of this module (kept as string literals so static tools can read it).
__all__ = [
    "LocalElasticAgent",
    "TORCHELASTIC_ENABLE_FILE_TIMER",
    "TORCHELASTIC_TIMER_FILE",
    "TORCHELASTIC_HEALTH_CHECK_PORT",
]

logger = get_logger(__name__)
58
59
class LocalElasticAgent(SimpleElasticAgent):
    """An implementation of :py:class:`torch.distributed.elastic.agent.server.ElasticAgent` that handles host-local workers.

    This agent is deployed per host and is configured to spawn ``n`` workers.
    When using GPUs, ``n`` maps to the number of GPUs available on the host.

    The local agent does not communicate with other local agents deployed on
    other hosts, even if the workers may communicate inter-host. Each worker id
    is expected to be the pid of a local process. The agent starts and stops
    all worker processes as a single unit.


    The worker function and argument passed to the worker function must be
    python multiprocessing compatible. To pass multiprocessing data structures
    to the workers you may create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function argument.

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
    for other agents to finish. This acts as a safety net to handle cases where
    workers finish at different times, to prevent agents from viewing workers
    that finished early as a scale-down event. It is strongly advised that the
    user code deal with ensuring that workers are terminated in a synchronous
    manner rather than relying on the exit_barrier_timeout.

    A named pipe based watchdog can be enabled in ``LocalElasticAgent`` if an
    environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has
    been defined in the ``LocalElasticAgent`` process.
    Optionally, another environment variable ``TORCHELASTIC_TIMER_FILE``
    can be set with a unique file name for the named pipe. If the environment
    variable ``TORCHELASTIC_TIMER_FILE`` is not set, ``LocalElasticAgent``
    will internally create a unique file name (under ``/tmp``, regenerated
    every time workers are started). Whenever a file name is known, it is
    exported as ``TORCHELASTIC_TIMER_FILE`` in each worker process's
    environment so the workers can connect to the same named pipe that
    ``LocalElasticAgent`` uses.

    If the environment variable ``TORCHELASTIC_HEALTH_CHECK_PORT`` is set, a
    health check server is started on that port every time workers are
    started. Its liveness callback reports the watchdog's last progress time
    when the watchdog is enabled, and the current time otherwise.

    Logs are written to the specified log directory. Each log line will be by default
    prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``).
    Log prefixes can be customized by passing a `template string
    <https://docs.python.org/3/library/string.html#template-strings>`_ as the
    ``log_line_prefix_template`` argument.
    The following macros (identifiers) are substituted at runtime:
    ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with
    global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:"``.


    Example launching function

    ::

        def trainer(args) -> str:
            return "do train"

        def main():
            start_method="spawn"
            shared_queue= multiprocessing.get_context(start_method).Queue()
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint=trainer,
                        args=("foobar",),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(
                spec, logs_specs=DefaultLogsSpecs(), start_method=start_method
            )
            results = agent.run()

            if results.is_failed():
                print("trainer failed")
            else:
                print(f"rank 0 return value: {results.return_values[0]}")
                # prints -> rank 0 return value: do train

    Example launching binary

    ::

        def main():
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint="/usr/local/bin/trainer",
                        args=("--trainer-args", "foobar"),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec, logs_specs=DefaultLogsSpecs())
            results = agent.run()

            if not results.is_failed():
                print("binary launches do not have return values")

    """

    def __init__(
        self,
        spec: WorkerSpec,
        logs_specs: LogsSpecs,
        start_method: str = "spawn",
        exit_barrier_timeout: float = 300,
        log_line_prefix_template: Optional[str] = None,
    ):
        """Create the agent.

        Args:
            spec: worker specification. Its ``rdzv_handler`` is kept so the
                agent can shut it down in ``_shutdown``.
            logs_specs: log configuration forwarded to ``start_processes``.
            start_method: multiprocessing start method forwarded to
                ``start_processes``.
            exit_barrier_timeout: seconds to wait for other agents to finish;
                forwarded to ``SimpleElasticAgent``.
            log_line_prefix_template: optional ``string.Template`` used to
                build a per-worker log line prefix (see the class docstring).
        """
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None
        self._rdzv_handler = spec.rdzv_handler
        self._log_line_prefix_template = log_line_prefix_template
        # Created in ``_start_workers`` (if enabled) and torn down in ``_shutdown``.
        self._worker_watchdog: Optional[timer.FileTimerServer] = None
        self._logs_specs = logs_specs
        # Created in ``_start_workers`` (if configured) and torn down in ``_shutdown``.
        self._health_check_server: Optional[HealthCheckServer] = None

    def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
        """Start the file-timer watchdog if enabled and export its file path to workers.

        The watchdog starts only when ``TORCHELASTIC_ENABLE_FILE_TIMER`` equals
        ``"1"``. ``envs`` (local rank -> worker environment) is mutated in
        place: whenever a timer file path is known (taken from
        ``TORCHELASTIC_TIMER_FILE`` or generated here), it is written into
        every worker's environment. This happens even when the watchdog itself
        is disabled.
        """
        enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
        watchdog_enabled = os.getenv(enable_watchdog_env_name)
        watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
        watchdog_file_path = os.getenv(watchdog_file_env_name)
        # os.getenv returns either None or a str, so a direct comparison suffices.
        if watchdog_enabled == "1":
            if watchdog_file_path is None:
                watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
            logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
            if not envs:
                logger.warning(
                    "Empty envs variables, using empty run_id for FileTimerServer"
                )
                run_id = ""
            else:
                # Every worker carries the same run id; local rank 0 is read.
                run_id = envs[0]["TORCHELASTIC_RUN_ID"]
            self._worker_watchdog = timer.FileTimerServer(
                file_path=watchdog_file_path,
                run_id=run_id,
                max_interval=0.1,
                daemon=True,
                log_event=self._log_watchdog_event,
            )
            self._worker_watchdog.start()
            logger.info("FileTimerServer started")
        else:
            logger.info(
                "Environment variable '%s' not found. Do not start FileTimerServer.",
                enable_watchdog_env_name,
            )
        # Propagate the watchdog file env to worker processes
        if watchdog_file_path is not None:
            for worker_env in envs.values():
                worker_env[watchdog_file_env_name] = watchdog_file_path

    @staticmethod
    def _get_current_time_secs() -> int:
        """Return the current wall-clock time as whole seconds since the epoch."""
        return int(time.time())

    def _setup_healthcheck(self) -> None:
        """Start a health check server if ``TORCHELASTIC_HEALTH_CHECK_PORT`` is set.

        The server's liveness callback is the watchdog's last progress time
        when a watchdog is running; otherwise the current time is reported, so
        the agent always appears alive. ``int()`` raises ``ValueError`` if the
        port value is not an integer.
        """
        healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT
        healthcheck_port = os.getenv(healthcheck_port_env_name)
        if healthcheck_port is not None:
            logger.info(
                "Found healthcheck port %s: %s",
                healthcheck_port_env_name,
                healthcheck_port,
            )
            if self._worker_watchdog is None:
                logger.info(
                    "FileTimerServer doesn't exist, using current time as dummy callback"
                )
                alive_callback = LocalElasticAgent._get_current_time_secs
            else:
                alive_callback = self._worker_watchdog.get_last_progress_time

            self._health_check_server = create_healthcheck_server(
                alive_callback=alive_callback,
                port=int(healthcheck_port),
                timeout=60,
            )
            self._health_check_server.start()
        else:
            logger.info(
                "Environment variable '%s' not found. Do not start health check.",
                healthcheck_port_env_name,
            )

    def _get_fq_hostname(self) -> str:
        """Return the fully qualified domain name of this host."""
        return socket.getfqdn(socket.gethostname())

    def _log_watchdog_event(
        self,
        name: str,
        request: Optional[timer.FileTimerRequest],
    ) -> None:
        """Record an agent event for a watchdog notification.

        Passed as the ``log_event`` callback of the ``FileTimerServer``. When
        ``request`` is given, its pid, scope id, expiration time and signal are
        JSON-encoded into the event's ``metadata`` entry.
        """
        wg = self._worker_group
        spec = wg.spec
        md = {"watchdog_event": name}
        if request is not None:
            md["worker_pid"] = str(request.worker_pid)
            md["scope_id"] = request.scope_id
            md["expiration_time"] = str(request.expiration_time)
            md["signal"] = str(request.signal)
        md_str = json.dumps(md)
        state = "RUNNING"
        # EventMetadataValue is imported only under TYPE_CHECKING; annotations on
        # locals are never evaluated at runtime, so this is safe.
        metadata: Dict[str, EventMetadataValue] = {
            "run_id": spec.rdzv_handler.get_run_id(),
            "global_rank": None,
            "group_rank": wg.group_rank,
            "worker_id": None,
            "role": spec.role,
            "hostname": self._get_fq_hostname(),
            "state": state,
            "total_run_time": self._total_execution_time,
            "rdzv_backend": spec.rdzv_handler.get_backend(),
            "raw_error": None,
            "metadata": md_str,
            "agent_restarts": spec.max_restarts - self._remaining_restarts,
        }
        # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later.
        #       The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry.
        event = events.Event(
            name=name, source=events.EventSource.AGENT, metadata=metadata
        )
        events.record(event)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _stop_workers(
        self, worker_group: WorkerGroup, is_restart: bool = False
    ) -> None:
        """Stop all local workers via ``_shutdown``.

        The rendezvous handler is kept alive when ``is_restart`` is True.
        """
        self._shutdown(is_restart=is_restart)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        """Build per-worker args/envs, start auxiliary services and launch workers.

        Before the workers are launched, this method:

        * builds each worker's environment and positional args;
        * optionally builds its log line prefix;
        * starts the watchdog and health check server when they are configured.

        Returns:
            The process context's pid map (keys are presumably local ranks,
            matching the ``args``/``envs`` dicts built here).
        """
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store: bool = spec.rdzv_handler.use_agent_store
        logger.info("use_agent_store: %s", use_agent_store)

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        # Only populated when a prefix template was given; None otherwise.
        log_line_prefixes: Optional[Dict[int, str]] = (
            {} if self._log_line_prefix_template else None
        )
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": worker_group.master_addr,
                "MASTER_PORT": str(worker_group.master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                # Inherit the agent's setting, defaulting to "1" when unset.
                "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv(
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
                ),
            }
            # Forward OMP_NUM_THREADS only if the agent itself has it set.
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

            if log_line_prefixes is not None:
                log_line_prefix = Template(
                    self._log_line_prefix_template
                ).safe_substitute(
                    role_name=spec.role,
                    rank=worker.global_rank,
                    local_rank=local_rank,
                )
                log_line_prefixes[local_rank] = log_line_prefix

            envs[local_rank] = worker_env
            # Expand argument macros using this worker's local rank.
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # The watchdog must be set up first: it may add TORCHELASTIC_TIMER_FILE
        # to ``envs``, and the health check uses it as its liveness source.
        self._setup_local_watchdog(envs=envs)
        self._setup_healthcheck()

        assert spec.entrypoint is not None
        assert self._logs_specs is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            logs_specs=self._logs_specs,
            log_line_prefixes=log_line_prefixes,
            start_method=self._start_method,
        )

        return self._pcontext.pids()

    def _shutdown(
        self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
    ) -> None:
        """Tear down everything this agent started on the local host.

        Stops the watchdog and health check server (if running) and closes the
        worker process context with ``death_sig``. The rendezvous handler is
        shut down only when ``is_restart`` is False.
        """
        if self._worker_watchdog is not None:
            self._worker_watchdog.stop()
            self._worker_watchdog = None
        if self._health_check_server is not None:
            self._health_check_server.stop()
            self._health_check_server = None
        if self._pcontext:
            self._pcontext.close(death_sig)
        if not is_restart and self._rdzv_handler:
            self._rdzv_handler.shutdown()

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        """Poll the worker processes (zero timeout) and summarize their state.

        Returns:
            * ``UNKNOWN`` if the worker ids disagree with the process context pids;
            * ``FAILED`` with failures keyed by global rank;
            * ``SUCCEEDED`` with return values keyed by global rank;
            * ``HEALTHY`` while no result is available yet.
        """
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            logger.error(
                "[%s] worker pids do not match process_context pids."
                " Expected: %s, actual: %s",
                role,
                worker_pids,
                pc_pids,
            )
            return RunResult(state=WorkerState.UNKNOWN)

        result = self._pcontext.wait(0)
        if result:
            if result.is_failed():
                # map local rank failure to global rank
                worker_failures = {}
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,
                )
        else:
            return RunResult(state=WorkerState.HEALTHY)
411