#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import json
import os
import signal
import socket
import time
import uuid
from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

import torch.distributed.elastic.timer as timer
from torch.distributed.elastic import events
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.agent.server.health_check_server import (
    create_healthcheck_server,
    HealthCheckServer,
)
from torch.distributed.elastic.metrics.api import prof
from torch.distributed.elastic.multiprocessing import (
    LogsSpecs,
    PContext,
    start_processes,
)
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger


if TYPE_CHECKING:
    from torch.distributed.elastic.events.api import EventMetadataValue

logger = get_logger(__name__)

__all__ = [
    "LocalElasticAgent",
    "TORCHELASTIC_ENABLE_FILE_TIMER",
    "TORCHELASTIC_TIMER_FILE",
    "TORCHELASTIC_HEALTH_CHECK_PORT",
]

TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT"
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"


class LocalElasticAgent(SimpleElasticAgent):
    """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.

    This agent is deployed per host and is configured to spawn ``n`` workers.
    When using GPUs, ``n`` maps to the number of GPUs available on the host.

    The local agent does not communicate with other local agents deployed on
    other hosts, even though the workers may communicate inter-host. The worker
    id is interpreted to be a local process. The agent starts and stops all
    worker processes as a single unit.


    The worker function and the arguments passed to it must be compatible with
    Python multiprocessing. To pass multiprocessing data structures to the
    workers, create the data structure in the same multiprocessing context as
    the specified ``start_method`` and pass it as a function argument.

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
    for other agents to finish. This acts as a safety net to handle cases where
    workers finish at different times, to prevent agents from viewing workers
    that finished early as a scale-down event. It is strongly advised that user
    code ensure workers are terminated in a synchronous manner rather than
    relying on ``exit_barrier_timeout``.

    A named-pipe-based watchdog can be enabled in ``LocalElasticAgent`` if the
    environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` is set to ``1`` in
    the ``LocalElasticAgent`` process.
    Optionally, another environment variable ``TORCHELASTIC_TIMER_FILE``
    can be set to a unique file name for the named pipe.
    If the environment
    variable ``TORCHELASTIC_TIMER_FILE`` is not set, ``LocalElasticAgent``
    will internally create a unique file name and set it to the environment
    variable ``TORCHELASTIC_TIMER_FILE``, and this environment variable will
    be propagated to the worker processes to allow them to connect to the same
    named pipe that ``LocalElasticAgent`` uses. (A commented usage sketch at
    the end of this file shows one way these variables might be set.)

    Logs are written to the specified log directory. By default, each log line
    is prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``).
    Log prefixes can be customized by passing a `template string
    <https://docs.python.org/3/library/string.html#template-strings>`_ as the
    ``log_line_prefix_template`` argument.
    The following macros (identifiers) are substituted at runtime:
    ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with
    the global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:"``.


    Example launching a function

    ::

        def trainer(args) -> str:
            return "do train"

        def main():
            start_method = "spawn"
            shared_queue = multiprocessing.get_context(start_method).Queue()
            spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_process,
                entrypoint=trainer,
                args=("foobar",),
                ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec, logs_specs, start_method)
            results = agent.run()

            if results.is_failed():
                print("trainer failed")
            else:
                print(f"rank 0 return value: {results.return_values[0]}")
                # prints -> rank 0 return value: do train

    Example launching a binary

    ::

        def main():
            spec = WorkerSpec(
                role="trainer",
                local_world_size=nproc_per_process,
                entrypoint="/usr/local/bin/trainer",
                args=("--trainer-args", "foobar"),
                ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec, logs_specs)
            results = agent.run()

            if not results.is_failed():
                print("binary launches do not have return values")

    """

    def __init__(
        self,
        spec: WorkerSpec,
        logs_specs: LogsSpecs,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_line_prefix_template: Optional[str] = None,
    ):
        super().__init__(spec, exit_barrier_timeout)
        self._start_method = start_method
        self._pcontext: Optional[PContext] = None
        self._rdzv_handler = spec.rdzv_handler
        self._log_line_prefix_template = log_line_prefix_template
        self._worker_watchdog: Optional[timer.FileTimerServer] = None
        self._logs_specs = logs_specs
        self._health_check_server: Optional[HealthCheckServer] = None

    def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
        enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
        watchdog_enabled = os.getenv(enable_watchdog_env_name)
        watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
        watchdog_file_path = os.getenv(watchdog_file_env_name)
        if watchdog_enabled is not None and str(watchdog_enabled) == "1":
            if watchdog_file_path is None:
                watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
            logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
            if not envs:
                logger.warning(
                    "Empty envs variables, using empty run_id for FileTimerServer"
                )
                run_id = ""
            else:
                run_id = envs[0]["TORCHELASTIC_RUN_ID"]
            self._worker_watchdog = timer.FileTimerServer(
                file_path=watchdog_file_path,
                run_id=run_id,
                max_interval=0.1,
                daemon=True,
                log_event=self._log_watchdog_event,
            )
            self._worker_watchdog.start()
            logger.info("FileTimerServer started")
        else:
            logger.info(
                "Environment variable '%s' not found. Do not start FileTimerServer.",
                enable_watchdog_env_name,
            )
        # Propagate the watchdog file env to worker processes
        if watchdog_file_path is not None:
            for worker_env in envs.values():
                worker_env[watchdog_file_env_name] = watchdog_file_path

    @staticmethod
    def _get_current_time_secs() -> int:
        return int(time.time())

    def _setup_healthcheck(self) -> None:
        healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT
        healthcheck_port = os.getenv(healthcheck_port_env_name)
        if healthcheck_port is not None:
            logger.info(
                "Found healthcheck port %s: %s",
                healthcheck_port_env_name,
                healthcheck_port,
            )
            if self._worker_watchdog is None:
                logger.info(
                    "FileTimerServer doesn't exist, using current time as dummy callback"
                )
                alive_callback = LocalElasticAgent._get_current_time_secs
            else:
                alive_callback = self._worker_watchdog.get_last_progress_time

            self._health_check_server = create_healthcheck_server(
                alive_callback=alive_callback,
                port=int(healthcheck_port),
                timeout=60,
            )
            self._health_check_server.start()
        else:
            logger.info(
                "Environment variable '%s' not found. Do not start health check.",
                healthcheck_port_env_name,
            )

    def _get_fq_hostname(self) -> str:
        return socket.getfqdn(socket.gethostname())

    def _log_watchdog_event(
        self,
        name: str,
        request: Optional[timer.FileTimerRequest],
    ) -> None:
        wg = self._worker_group
        spec = wg.spec
        md = {"watchdog_event": name}
        if request is not None:
            md["worker_pid"] = str(request.worker_pid)
            md["scope_id"] = request.scope_id
            md["expiration_time"] = str(request.expiration_time)
            md["signal"] = str(request.signal)
        md_str = json.dumps(md)
        state = "RUNNING"
        metadata: Dict[str, EventMetadataValue] = {
            "run_id": spec.rdzv_handler.get_run_id(),
            "global_rank": None,
            "group_rank": wg.group_rank,
            "worker_id": None,
            "role": spec.role,
            "hostname": self._get_fq_hostname(),
            "state": state,
            "total_run_time": self._total_execution_time,
            "rdzv_backend": spec.rdzv_handler.get_backend(),
            "raw_error": None,
            "metadata": md_str,
            "agent_restarts": spec.max_restarts - self._remaining_restarts,
        }
        # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later.
        # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry.
        event = events.Event(
            name=name, source=events.EventSource.AGENT, metadata=metadata
        )
        events.record(event)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    # `torch.distributed.elastic.metrics.prof`.
    @prof
    def _stop_workers(
        self, worker_group: WorkerGroup, is_restart: bool = False
    ) -> None:
        self._shutdown(is_restart=is_restart)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    # `torch.distributed.elastic.metrics.prof`.
    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store: bool = spec.rdzv_handler.use_agent_store
        logger.info("use_agent_store: %s", use_agent_store)

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        log_line_prefixes: Optional[Dict[int, str]] = (
            {} if self._log_line_prefix_template else None
        )
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": worker_group.master_addr,
                "MASTER_PORT": str(worker_group.master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv(
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
                ),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

            if self._log_line_prefix_template:
                log_line_prefix = Template(
                    self._log_line_prefix_template
                ).safe_substitute(
                    role_name=spec.role,
                    rank=worker.global_rank,
                    local_rank=local_rank,
                )
                log_line_prefixes[local_rank] = log_line_prefix

            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        self._setup_local_watchdog(envs=envs)
        self._setup_healthcheck()

        assert spec.entrypoint is not None
        assert self._logs_specs is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            logs_specs=self._logs_specs,
            log_line_prefixes=log_line_prefixes,
            start_method=self._start_method,
        )

        return self._pcontext.pids()

    def _shutdown(
        self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
    ) -> None:
        if self._worker_watchdog is not None:
            self._worker_watchdog.stop()
            self._worker_watchdog = None
        if self._health_check_server is not None:
            self._health_check_server.stop()
            self._health_check_server = None
        if self._pcontext:
            self._pcontext.close(death_sig)
        if not is_restart and self._rdzv_handler:
            self._rdzv_handler.shutdown()

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    # `torch.distributed.elastic.metrics.prof`.
    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            logger.error(
                "[%s] worker pids do not match process_context pids."
380 " Expected: %s, actual: %s", 381 role, 382 worker_pids, 383 pc_pids, 384 ) 385 return RunResult(state=WorkerState.UNKNOWN) 386 387 result = self._pcontext.wait(0) 388 if result: 389 if result.is_failed(): 390 # map local rank failure to global rank 391 worker_failures = {} 392 for local_rank, failure in result.failures.items(): 393 worker = worker_group.workers[local_rank] 394 worker_failures[worker.global_rank] = failure 395 return RunResult( 396 state=WorkerState.FAILED, 397 failures=worker_failures, 398 ) 399 else: 400 # copy ret_val_queue into a map with a global ranks 401 workers_ret_vals = {} 402 for local_rank, ret_val in result.return_values.items(): 403 worker = worker_group.workers[local_rank] 404 workers_ret_vals[worker.global_rank] = ret_val 405 return RunResult( 406 state=WorkerState.SUCCEEDED, 407 return_values=workers_ret_vals, 408 ) 409 else: 410 return RunResult(state=WorkerState.HEALTHY) 411