1# mypy: allow-untyped-defs 2# Copyright (c) Meta Platforms, Inc. and its affiliates. 3# All rights reserved. 4# 5# This source code is licensed under the BSD-style license found in the 6# LICENSE file in the root directory of this source tree. 7 8import io 9import json 10import os 11import select 12import signal 13import sys 14import threading 15import time 16from typing import Callable, Dict, List, Optional, Set, Tuple 17 18from torch.distributed.elastic.timer.api import TimerClient, TimerRequest 19from torch.distributed.elastic.timer.debug_info_logging import ( 20 log_debug_info_for_expired_timers, 21) 22from torch.distributed.elastic.utils.logging import get_logger 23 24 25__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] 26 27logger = get_logger(__name__) 28 29 30class FileTimerRequest(TimerRequest): 31 """ 32 Data object representing a countdown timer acquisition and release 33 that is used between the ``FileTimerClient`` and ``FileTimerServer``. 34 A negative ``expiration_time`` should be interpreted as a "release" 35 request. 36 ``signal`` is the signal to reap the worker process from the server 37 process. 38 """ 39 40 __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"] 41 42 def __init__( 43 self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 44 ) -> None: 45 self.version = 1 46 self.worker_pid = worker_pid 47 self.scope_id = scope_id 48 self.expiration_time = expiration_time 49 self.signal = signal 50 51 def __eq__(self, other) -> bool: 52 if isinstance(other, FileTimerRequest): 53 return ( 54 self.version == other.version 55 and self.worker_pid == other.worker_pid 56 and self.scope_id == other.scope_id 57 and self.expiration_time == other.expiration_time 58 and self.signal == other.signal 59 ) 60 return False 61 62 def to_json(self) -> str: 63 return json.dumps( 64 { 65 "version": self.version, 66 "pid": self.worker_pid, 67 "scope_id": self.scope_id, 68 "expiration_time": self.expiration_time, 69 "signal": self.signal, 70 }, 71 ) 72 73 74class FileTimerClient(TimerClient): 75 """ 76 Client side of ``FileTimerServer``. This client is meant to be used 77 on the same host that the ``FileTimerServer`` is running on and uses 78 pid to uniquely identify a worker. 79 This client uses a named_pipe to send timer requests to the 80 ``FileTimerServer``. This client is a producer while the 81 ``FileTimerServer`` is a consumer. Multiple clients can work with 82 the same ``FileTimerServer``. 83 84 Args: 85 86 file_path: str, the path of a FIFO special file. ``FileTimerServer`` 87 must have created it by calling os.mkfifo(). 88 89 signal: signal, the signal to use to kill the process. Using a 90 negative or zero signal will not kill the process. 91 """ 92 93 def __init__( 94 self, 95 file_path: str, 96 signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] 97 ) -> None: 98 super().__init__() 99 self._file_path = file_path 100 self.signal = signal 101 102 def _open_non_blocking(self) -> Optional[io.TextIOWrapper]: 103 try: 104 fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK) 105 return os.fdopen(fd, "wt") 106 except Exception: 107 return None 108 109 def _send_request(self, request: FileTimerRequest) -> None: 110 # The server may have crashed or may haven't started yet. 111 # In such case, calling open() in blocking model blocks the client. 112 # To avoid such issue, open it in non-blocking mode, and an OSError will 113 # be raised if the server is not there. 114 file = self._open_non_blocking() 115 if file is None: 116 raise BrokenPipeError( 117 "Could not send the FileTimerRequest because FileTimerServer is not available." 118 ) 119 with file: 120 json_request = request.to_json() 121 # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. 122 if len(json_request) > select.PIPE_BUF: 123 raise RuntimeError( 124 f"FileTimerRequest larger than {select.PIPE_BUF} bytes " 125 f"is not supported: {json_request}" 126 ) 127 file.write(json_request + "\n") 128 129 def acquire(self, scope_id: str, expiration_time: float) -> None: 130 self._send_request( 131 request=FileTimerRequest( 132 worker_pid=os.getpid(), 133 scope_id=scope_id, 134 expiration_time=expiration_time, 135 signal=self.signal, 136 ), 137 ) 138 139 def release(self, scope_id: str) -> None: 140 self._send_request( 141 request=FileTimerRequest( 142 worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 143 ), 144 ) 145 146 147class FileTimerServer: 148 """ 149 Server that works with ``FileTimerClient``. Clients are expected to be 150 running on the same host as the process that is running this server. 151 Each host in the job is expected to start its own timer server locally 152 and each server instance manages timers for local workers (running on 153 processes on the same host). 154 155 Args: 156 157 file_path: str, the path of a FIFO special file to be created. 158 159 max_interval: float, max interval in seconds for each watchdog loop. 160 161 daemon: bool, running the watchdog thread in daemon mode or not. 162 A daemon thread will not block a process to stop. 163 log_event: Callable[[Dict[str, str]], None], an optional callback for 164 logging the events in JSON format. 165 """ 166 167 def __init__( 168 self, 169 file_path: str, 170 run_id: str, 171 max_interval: float = 10, 172 daemon: bool = True, 173 log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None, 174 ) -> None: 175 self._file_path = file_path 176 self._run_id = run_id 177 self._max_interval = max_interval 178 self._daemon = daemon 179 self._timers: Dict[Tuple[int, str], FileTimerRequest] = {} 180 self._stop_signaled = False 181 self._watchdog_thread: Optional[threading.Thread] = None 182 if os.path.exists(self._file_path): 183 os.remove(self._file_path) 184 os.mkfifo(self._file_path) 185 # For test only. Count the number of requests received. 186 self._request_count = 0 187 # For test only. Process all requests and stop the server. 188 self._run_once = False 189 self._log_event = ( 190 log_event if log_event is not None else lambda name, request: None 191 ) 192 self._last_progress_time = int(time.time()) 193 194 def start(self) -> None: 195 logger.info( 196 "Starting %s... max_interval=%s, daemon=%s, file_path=%s", 197 type(self).__name__, 198 self._max_interval, 199 self._daemon, 200 self._file_path, 201 ) 202 self._watchdog_thread = threading.Thread( 203 target=self._watchdog_loop, daemon=self._daemon 204 ) 205 logger.info("Starting watchdog thread...") 206 self._watchdog_thread.start() 207 self._log_event("watchdog started", None) 208 209 def stop(self) -> None: 210 logger.info("Stopping %s", type(self).__name__) 211 self._stop_signaled = True 212 if self._watchdog_thread: 213 logger.info("Stopping watchdog thread...") 214 self._watchdog_thread.join(self._max_interval) 215 self._watchdog_thread = None 216 else: 217 logger.info("No watchdog thread running, doing nothing") 218 if os.path.exists(self._file_path): 219 os.remove(self._file_path) 220 self._log_event("watchdog stopped", None) 221 222 def run_once(self) -> None: 223 self._run_once = True 224 if self._watchdog_thread: 225 logger.info("Stopping watchdog thread...") 226 self._watchdog_thread.join() 227 self._watchdog_thread = None 228 else: 229 logger.info("No watchdog thread running, doing nothing") 230 if os.path.exists(self._file_path): 231 os.remove(self._file_path) 232 233 @staticmethod 234 def is_process_running(pid: int): 235 """ 236 function to check process is running or not 237 """ 238 try: 239 # Check if the process exists and we can send signals to it 240 os.kill(pid, 0) 241 return True 242 except OSError: 243 return False 244 245 def _watchdog_loop(self) -> None: 246 # Open the pipe in blocking mode blocks the server thread. 247 # This is fine for the following reasons: 248 # 1. No client case usually does not happen. 249 # 2. We are running the watchdog loop in a separate daemon 250 # thread, which will not block the process to stop. 251 with open(self._file_path) as fd: 252 while not self._stop_signaled: 253 try: 254 run_once = self._run_once 255 self._run_watchdog(fd) 256 if run_once: 257 break 258 self._last_progress_time = int(time.time()) 259 except Exception: 260 logger.exception("Error running watchdog") 261 262 def _run_watchdog(self, fd: io.TextIOWrapper) -> None: 263 timer_requests = self._get_requests(fd, self._max_interval) 264 self.register_timers(timer_requests) 265 now = time.time() 266 reaped_worker_pids = set() 267 268 all_expired_timers = self.get_expired_timers(now) 269 log_debug_info_for_expired_timers( 270 self._run_id, 271 { 272 pid: self._get_scopes(expired_timers) 273 for pid, expired_timers in all_expired_timers.items() 274 }, 275 ) 276 277 for worker_pid, expired_timers in all_expired_timers.items(): 278 logger.info( 279 "Reaping worker_pid=[%s]. Expired timers: %s", 280 worker_pid, 281 self._get_scopes(expired_timers), 282 ) 283 reaped_worker_pids.add(worker_pid) 284 # In case we have multiple expired timers, we find the first timer 285 # with a valid signal (>0) in the expiration time order. 286 expired_timers.sort(key=lambda timer: timer.expiration_time) 287 signal = 0 288 expired_timer = None 289 for timer in expired_timers: 290 self._log_event("timer expired", timer) 291 if timer.signal > 0: 292 signal = timer.signal 293 expired_timer = timer 294 break 295 if signal <= 0: 296 logger.info( 297 "No signal specified with worker=[%s]. Do not reap it.", worker_pid 298 ) 299 continue 300 if self._reap_worker(worker_pid, signal): 301 logger.info( 302 "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal 303 ) 304 self._log_event("kill worker process", expired_timer) 305 else: 306 logger.error( 307 "Error reaping worker=[%s]. Will retry on next watchdog.", 308 worker_pid, 309 ) 310 self.clear_timers(reaped_worker_pids) 311 312 def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]: 313 return [r.scope_id for r in timer_requests] 314 315 def _get_requests( 316 self, fd: io.TextIOWrapper, max_interval: float 317 ) -> List[FileTimerRequest]: 318 start = time.time() 319 requests = [] 320 while not self._stop_signaled or self._run_once: 321 # For named pipe, readline() is blocking when at least one writer opens. 322 # It returns only when flush() is called at the writer side. 323 # Note that flush() is automatically called inside close(). 324 # After the last writer closes, readline() is not blocking. 325 # It will return an empty string when it's at end-of-file. 326 # Since the client side always opens the pipe, writes a message and closes 327 # the pipe immediately, the readline() call below is not blocking for long. 328 json_request = fd.readline() 329 if len(json_request) == 0: 330 if self._run_once: 331 break 332 time.sleep(min(max_interval, 1)) 333 else: 334 request = json.loads(json_request) 335 pid = request["pid"] 336 scope_id = request["scope_id"] 337 expiration_time = request["expiration_time"] 338 signal = request["signal"] 339 requests.append( 340 FileTimerRequest( 341 worker_pid=pid, 342 scope_id=scope_id, 343 expiration_time=expiration_time, 344 signal=signal, 345 ) 346 ) 347 now = time.time() 348 if now - start > max_interval: 349 break 350 return requests 351 352 def register_timers(self, timer_requests: List[FileTimerRequest]) -> None: 353 for request in timer_requests: 354 pid = request.worker_pid 355 scope_id = request.scope_id 356 expiration_time = request.expiration_time 357 self._request_count += 1 358 359 key = (pid, scope_id) 360 # negative expiration is a proxy for a release call 361 if expiration_time < 0: 362 if key in self._timers: 363 del self._timers[key] 364 else: 365 self._timers[key] = request 366 367 def clear_timers(self, worker_pids: Set[int]) -> None: 368 for pid, scope_id in list(self._timers.keys()): 369 if pid in worker_pids or not FileTimerServer.is_process_running(pid): 370 del self._timers[(pid, scope_id)] 371 372 def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]: 373 # pid -> [timer_requests...] 374 expired_timers: Dict[int, List[FileTimerRequest]] = {} 375 for request in self._timers.values(): 376 if request.expiration_time <= deadline: 377 expired_scopes = expired_timers.setdefault(request.worker_pid, []) 378 expired_scopes.append(request) 379 return expired_timers 380 381 def _reap_worker(self, worker_pid: int, signal: int) -> bool: 382 try: 383 os.kill(worker_pid, signal) 384 return True 385 except ProcessLookupError: 386 logger.info("Process with pid=%s does not exist. Skipping", worker_pid) 387 return True 388 except Exception: 389 logger.exception("Error terminating pid=%s", worker_pid) 390 return False 391 392 def get_last_progress_time(self) -> int: 393 return self._last_progress_time 394