xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/timer/file_based_local_timer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Copyright (c) Meta Platforms, Inc. and its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import io
9import json
10import os
11import select
12import signal
13import sys
14import threading
15import time
16from typing import Callable, Dict, List, Optional, Set, Tuple
17
18from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
19from torch.distributed.elastic.timer.debug_info_logging import (
20    log_debug_info_for_expired_timers,
21)
22from torch.distributed.elastic.utils.logging import get_logger
23
24
25__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
26
27logger = get_logger(__name__)
28
29
class FileTimerRequest(TimerRequest):
    """
    Data object representing a countdown timer acquisition and release
    that is used between the ``FileTimerClient`` and ``FileTimerServer``.
    A negative ``expiration_time`` should be interpreted as a "release"
    request.
    ``signal`` is the signal to reap the worker process from the server
    process.
    """

    __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"]

    def __init__(
        self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0
    ) -> None:
        # Wire-format version; bump if the JSON schema in to_json() changes.
        self.version = 1
        self.worker_pid = worker_pid
        self.scope_id = scope_id
        self.expiration_time = expiration_time
        self.signal = signal

    def __eq__(self, other) -> bool:
        if isinstance(other, FileTimerRequest):
            return (
                self.version == other.version
                and self.worker_pid == other.worker_pid
                and self.scope_id == other.scope_id
                and self.expiration_time == other.expiration_time
                and self.signal == other.signal
            )
        return False

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None and makes instances
        # unhashable; hash over the same fields __eq__ compares so that
        # equal requests hash equally.
        return hash(
            (
                self.version,
                self.worker_pid,
                self.scope_id,
                self.expiration_time,
                self.signal,
            )
        )

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"worker_pid={self.worker_pid!r}, "
            f"scope_id={self.scope_id!r}, "
            f"expiration_time={self.expiration_time!r}, "
            f"signal={self.signal!r})"
        )

    def to_json(self) -> str:
        """Serialize this request to the one-line JSON format consumed by ``FileTimerServer``."""
        return json.dumps(
            {
                "version": self.version,
                "pid": self.worker_pid,
                "scope_id": self.scope_id,
                "expiration_time": self.expiration_time,
                "signal": self.signal,
            },
        )
72
73
class FileTimerClient(TimerClient):
    """
    Client side of ``FileTimerServer``. This client is meant to be used
    on the same host that the ``FileTimerServer`` is running on and uses
    pid to uniquely identify a worker.
    This client uses a named_pipe to send timer requests to the
    ``FileTimerServer``. This client is a producer while the
    ``FileTimerServer`` is a consumer. Multiple clients can work with
    the same ``FileTimerServer``.

    Args:

        file_path: str, the path of a FIFO special file. ``FileTimerServer``
                        must have created it by calling os.mkfifo().

        signal: signal, the signal to use to kill the process. Using a
                        negative or zero signal will not kill the process.
    """

    def __init__(
        self,
        file_path: str,
        signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT),  # type: ignore[attr-defined]
    ) -> None:
        super().__init__()
        self._file_path = file_path
        self.signal = signal

    def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
        """Open the FIFO for writing without blocking.

        Returns None when the open fails (e.g. no reader has the FIFO
        open, or the path does not exist) so the caller can report a
        meaningful error instead of hanging.
        """
        try:
            fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK)
            return os.fdopen(fd, "wt")
        except OSError:
            # Opening a FIFO with O_WRONLY | O_NONBLOCK raises OSError
            # (ENXIO) when no process has it open for reading; a missing
            # path raises FileNotFoundError (also an OSError). Catch only
            # OSError so programming errors are not silently swallowed.
            return None

    def _send_request(self, request: FileTimerRequest) -> None:
        # The server may have crashed or may not have started yet.
        # In such a case, calling open() in blocking mode blocks the client.
        # To avoid that, open the pipe in non-blocking mode; an OSError is
        # raised if the server is not there.
        file = self._open_non_blocking()
        if file is None:
            raise BrokenPipeError(
                "Could not send the FileTimerRequest because FileTimerServer is not available."
            )
        with file:
            json_request = request.to_json()
            # A write of no more than select.PIPE_BUF bytes is guaranteed to
            # be atomic, so concurrent clients do not interleave messages.
            # (json.dumps emits ASCII by default, so len() == byte length.)
            if len(json_request) > select.PIPE_BUF:
                raise RuntimeError(
                    f"FileTimerRequest larger than {select.PIPE_BUF} bytes "
                    f"is not supported: {json_request}"
                )
            file.write(json_request + "\n")

    def acquire(self, scope_id: str, expiration_time: float) -> None:
        """Start (or refresh) a countdown timer for this process and scope."""
        self._send_request(
            request=FileTimerRequest(
                worker_pid=os.getpid(),
                scope_id=scope_id,
                expiration_time=expiration_time,
                signal=self.signal,
            ),
        )

    def release(self, scope_id: str) -> None:
        """Cancel the timer for this scope; a negative expiration_time signals a release."""
        self._send_request(
            request=FileTimerRequest(
                worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0
            ),
        )
145
146
class FileTimerServer:
    """
    Server that works with ``FileTimerClient``. Clients are expected to be
    running on the same host as the process that is running this server.
    Each host in the job is expected to start its own timer server locally
    and each server instance manages timers for local workers (running on
    processes on the same host).

    Args:

        file_path: str, the path of a FIFO special file to be created.

        max_interval: float, max interval in seconds for each watchdog loop.

        daemon: bool, running the watchdog thread in daemon mode or not.
                      A daemon thread will not block a process to stop.
        log_event: Callable[[Dict[str, str]], None], an optional callback for
                logging the events in JSON format.
    """

    def __init__(
        self,
        file_path: str,
        run_id: str,
        max_interval: float = 10,
        daemon: bool = True,
        log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None,
    ) -> None:
        self._file_path = file_path
        self._run_id = run_id
        self._max_interval = max_interval
        self._daemon = daemon
        # Active timers keyed by (worker_pid, scope_id); a release request
        # (negative expiration_time) removes the entry.
        self._timers: Dict[Tuple[int, str], FileTimerRequest] = {}
        self._stop_signaled = False
        self._watchdog_thread: Optional[threading.Thread] = None
        # Recreate the FIFO so a stale file from a previous run is discarded.
        if os.path.exists(self._file_path):
            os.remove(self._file_path)
        os.mkfifo(self._file_path)
        # For test only. Count the number of requests received.
        self._request_count = 0
        # For test only. Process all requests and stop the server.
        self._run_once = False
        # No-op event logger by default; callers may inject a real one.
        self._log_event = (
            log_event if log_event is not None else lambda name, request: None
        )
        # Last time (epoch seconds) the watchdog completed a loop iteration;
        # exposed via get_last_progress_time() as a liveness signal.
        self._last_progress_time = int(time.time())

    def start(self) -> None:
        """Start the watchdog thread that reads requests from the FIFO and reaps workers."""
        logger.info(
            "Starting %s... max_interval=%s, daemon=%s, file_path=%s",
            type(self).__name__,
            self._max_interval,
            self._daemon,
            self._file_path,
        )
        self._watchdog_thread = threading.Thread(
            target=self._watchdog_loop, daemon=self._daemon
        )
        logger.info("Starting watchdog thread...")
        self._watchdog_thread.start()
        self._log_event("watchdog started", None)

    def stop(self) -> None:
        """Signal the watchdog thread to stop, wait up to max_interval, and remove the FIFO."""
        logger.info("Stopping %s", type(self).__name__)
        self._stop_signaled = True
        if self._watchdog_thread:
            logger.info("Stopping watchdog thread...")
            # Bounded join: the loop checks _stop_signaled at most every
            # max_interval seconds, so do not wait longer than that.
            self._watchdog_thread.join(self._max_interval)
            self._watchdog_thread = None
        else:
            logger.info("No watchdog thread running, doing nothing")
        if os.path.exists(self._file_path):
            os.remove(self._file_path)
        self._log_event("watchdog stopped", None)

    def run_once(self) -> None:
        """For test only: let the watchdog drain pending requests once, then shut down."""
        self._run_once = True
        if self._watchdog_thread:
            logger.info("Stopping watchdog thread...")
            # Unbounded join: the loop breaks out on its own after one pass.
            self._watchdog_thread.join()
            self._watchdog_thread = None
        else:
            logger.info("No watchdog thread running, doing nothing")
        if os.path.exists(self._file_path):
            os.remove(self._file_path)

    @staticmethod
    def is_process_running(pid: int):
        """
        function to check process is running or not
        """
        try:
            # Check if the process exists and we can send signals to it
            # (signal 0 performs the permission/existence check only).
            # NOTE(review): EPERM also raises OSError here even though the
            # process exists -- such a pid is reported as not running; confirm
            # this is acceptable for the workers this server manages.
            os.kill(pid, 0)
            return True
        except OSError:
            return False

    def _watchdog_loop(self) -> None:
        """Main loop of the watchdog thread: read requests, expire timers, reap workers."""
        # Open the pipe in blocking mode blocks the server thread.
        # This is fine for the following reasons:
        #  1. No client case usually does not happen.
        #  2. We are running the watchdog loop in a separate daemon
        #     thread, which will not block the process to stop.
        with open(self._file_path) as fd:
            while not self._stop_signaled:
                try:
                    # Snapshot _run_once before the pass so a flag flipped
                    # mid-pass still gets one full drain on the next check.
                    run_once = self._run_once
                    self._run_watchdog(fd)
                    if run_once:
                        break
                    self._last_progress_time = int(time.time())
                except Exception:
                    # Keep the watchdog alive; a single bad pass must not
                    # kill timer supervision for all workers.
                    logger.exception("Error running watchdog")

    def _run_watchdog(self, fd: io.TextIOWrapper) -> None:
        """One watchdog pass: ingest requests, then signal workers whose timers expired."""
        timer_requests = self._get_requests(fd, self._max_interval)
        self.register_timers(timer_requests)
        now = time.time()
        reaped_worker_pids = set()

        all_expired_timers = self.get_expired_timers(now)
        log_debug_info_for_expired_timers(
            self._run_id,
            {
                pid: self._get_scopes(expired_timers)
                for pid, expired_timers in all_expired_timers.items()
            },
        )

        for worker_pid, expired_timers in all_expired_timers.items():
            logger.info(
                "Reaping worker_pid=[%s]. Expired timers: %s",
                worker_pid,
                self._get_scopes(expired_timers),
            )
            reaped_worker_pids.add(worker_pid)
            # In case we have multiple expired timers, we find the first timer
            # with a valid signal (>0) in the expiration time order.
            expired_timers.sort(key=lambda timer: timer.expiration_time)
            signal = 0
            expired_timer = None
            for timer in expired_timers:
                self._log_event("timer expired", timer)
                if timer.signal > 0:
                    signal = timer.signal
                    expired_timer = timer
                    break
            if signal <= 0:
                logger.info(
                    "No signal specified with worker=[%s]. Do not reap it.", worker_pid
                )
                continue
            if self._reap_worker(worker_pid, signal):
                logger.info(
                    "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal
                )
                self._log_event("kill worker process", expired_timer)
            else:
                logger.error(
                    "Error reaping worker=[%s]. Will retry on next watchdog.",
                    worker_pid,
                )
        # Drop timers for reaped workers and for workers that died on their
        # own, so stale entries do not keep expiring forever.
        self.clear_timers(reaped_worker_pids)

    def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]:
        """Return the scope_ids of the given requests (used for logging)."""
        return [r.scope_id for r in timer_requests]

    def _get_requests(
        self, fd: io.TextIOWrapper, max_interval: float
    ) -> List[FileTimerRequest]:
        """Read newline-delimited JSON requests from the FIFO for up to max_interval seconds."""
        start = time.time()
        requests = []
        while not self._stop_signaled or self._run_once:
            # For named pipe, readline() is blocking when at least one writer opens.
            # It returns only when flush() is called at the writer side.
            # Note that flush() is automatically called inside close().
            # After the last writer closes, readline() is not blocking.
            # It will return an empty string when it's at end-of-file.
            # Since the client side always opens the pipe, writes a message and closes
            # the pipe immediately, the readline() call below is not blocking for long.
            json_request = fd.readline()
            if len(json_request) == 0:
                if self._run_once:
                    break
                # EOF (no writers): back off briefly instead of spinning.
                time.sleep(min(max_interval, 1))
            else:
                request = json.loads(json_request)
                pid = request["pid"]
                scope_id = request["scope_id"]
                expiration_time = request["expiration_time"]
                signal = request["signal"]
                requests.append(
                    FileTimerRequest(
                        worker_pid=pid,
                        scope_id=scope_id,
                        expiration_time=expiration_time,
                        signal=signal,
                    )
                )
            now = time.time()
            if now - start > max_interval:
                break
        return requests

    def register_timers(self, timer_requests: List[FileTimerRequest]) -> None:
        """Apply a batch of requests: acquisitions upsert timers, releases delete them."""
        for request in timer_requests:
            pid = request.worker_pid
            scope_id = request.scope_id
            expiration_time = request.expiration_time
            self._request_count += 1

            key = (pid, scope_id)
            # negative expiration is a proxy for a release call
            if expiration_time < 0:
                if key in self._timers:
                    del self._timers[key]
            else:
                self._timers[key] = request

    def clear_timers(self, worker_pids: Set[int]) -> None:
        """Remove all timers belonging to the given pids or to processes no longer running."""
        # Iterate over a snapshot of the keys since we delete while scanning.
        for pid, scope_id in list(self._timers.keys()):
            if pid in worker_pids or not FileTimerServer.is_process_running(pid):
                del self._timers[(pid, scope_id)]

    def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]:
        """Return timers whose expiration_time is at or before deadline, grouped by pid."""
        # pid -> [timer_requests...]
        expired_timers: Dict[int, List[FileTimerRequest]] = {}
        for request in self._timers.values():
            if request.expiration_time <= deadline:
                expired_scopes = expired_timers.setdefault(request.worker_pid, [])
                expired_scopes.append(request)
        return expired_timers

    def _reap_worker(self, worker_pid: int, signal: int) -> bool:
        """Send ``signal`` to ``worker_pid``; return True if sent or the process is already gone."""
        try:
            os.kill(worker_pid, signal)
            return True
        except ProcessLookupError:
            # Already dead counts as success -- nothing left to reap.
            logger.info("Process with pid=%s does not exist. Skipping", worker_pid)
            return True
        except Exception:
            logger.exception("Error terminating pid=%s", worker_pid)
        return False

    def get_last_progress_time(self) -> int:
        """Epoch seconds of the last completed watchdog iteration (liveness probe)."""
        return self._last_progress_time
394