# xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/timer/local_timer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
7import logging
8import multiprocessing as mp
9import os
10import signal
11import time
12from queue import Empty
13from typing import Any, Dict, List, Set, Tuple
14
15from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
16
17
# Public API of this module.
__all__ = [
    "LocalTimerClient",
    "MultiprocessingRequestQueue",
    "LocalTimerServer",
]

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
21
22
class LocalTimerClient(TimerClient):
    """
    Client half of ``LocalTimerServer``.

    Meant for workers running on the same host as the server. Every request
    is tagged with the calling process's pid (``os.getpid()``), and that pid
    is what identifies the worker. This suits setups that spawn one trainer
    subprocess per GPU on a host with several GPUs.

    Args:
        mp_queue: queue shared with the server; each request is ``put`` onto it.
    """

    def __init__(self, mp_queue):
        super().__init__()
        self._mp_queue = mp_queue

    def acquire(self, scope_id, expiration_time):
        """Send a request asking the server to track a timer for ``scope_id``."""
        # The pid is read on every call rather than cached, so the request
        # always carries the id of the process actually making the call.
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, expiration_time))

    def release(self, scope_id):
        """Send a request asking the server to drop the timer for ``scope_id``.

        A release is encoded as a request whose expiration time is ``-1``.
        """
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, -1))
45
46
class MultiprocessingRequestQueue(RequestQueue):
    """
    A ``RequestQueue`` backed by python ``multiprocessing.Queue``

    Args:
        mp_queue: the ``multiprocessing.Queue`` that timer clients put
            requests onto.
    """

    def __init__(self, mp_queue: mp.Queue):
        super().__init__()
        self._mp_queue = mp_queue

    def size(self) -> int:
        """Return the queue's ``qsize()``.

        NOTE: ``multiprocessing.Queue.qsize()`` is only approximate, and it
        may raise ``NotImplementedError`` on platforms without
        ``sem_getvalue()`` (e.g. macOS).
        """
        return self._mp_queue.qsize()

    def get(self, size, timeout: float) -> List[TimerRequest]:
        """
        Read up to ``size`` requests, sharing one ``timeout`` budget across reads.

        Reading stops early when a blocking read times out (queue stayed
        empty) or once the time budget is used up. The result may therefore
        be shorter than ``size``, or empty.

        Args:
            size: maximum number of requests to return.
            timeout: total time budget in seconds for all reads combined.

        Returns:
            The requests that were read, in the order they were dequeued.
        """
        requests: List[TimerRequest] = []
        # Track the budget on a monotonic clock: time.time() follows the wall
        # clock, which can jump (NTP adjustments, manual changes) and would
        # stretch or cut short the wait. A single deadline also accounts for
        # time spent between reads, not just inside them.
        deadline = time.monotonic() + timeout
        remaining = timeout
        for _ in range(size):
            try:
                request = self._mp_queue.get(block=True, timeout=remaining)
            except Empty:
                break

            requests.append(request)
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break

        return requests
76
77
class LocalTimerServer(TimerServer):
    """
    Server that works with ``LocalTimerClient``. Clients are expected to be
    subprocesses to the parent process that is running this server. Each host
    in the job is expected to start its own timer server locally and each
    server instance manages timers for local workers (running on processes
    on the same host).

    Args:
        mp_queue: queue the clients put their requests onto; it is wrapped in
            a ``MultiprocessingRequestQueue``.
        max_interval: forwarded unchanged to ``TimerServer``.
        daemon: forwarded unchanged to ``TimerServer``.
    """

    def __init__(
        self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
    ):
        super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
        # (worker_id, scope_id) -> latest acquire request registered for that key
        self._timers: Dict[Tuple[Any, str], TimerRequest] = {}

    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        """Apply acquire/release requests in the order given.

        A request with a negative ``expiration_time`` is a release: it removes
        the ``(worker_id, scope_id)`` timer if one exists. Any other request is
        stored, replacing an earlier one registered under the same key.
        """
        for request in timer_requests:
            pid = request.worker_id
            scope_id = request.scope_id
            expiration_time = request.expiration_time

            # negative expiration is a proxy for a release call
            if expiration_time < 0:
                self._timers.pop((pid, scope_id), None)
            else:
                self._timers[(pid, scope_id)] = request

    def clear_timers(self, worker_ids: Set[int]) -> None:
        """Remove every timer owned by a worker whose id is in ``worker_ids``."""
        # Iterate over a snapshot of the keys because entries are popped
        # from the dict inside the loop.
        for pid, scope_id in list(self._timers.keys()):
            if pid in worker_ids:
                self._timers.pop((pid, scope_id))

    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        """Group timers with ``expiration_time <= deadline`` by worker id.

        The timers are only reported, not removed.
        """
        # pid -> [timer_requests...]
        expired_timers: Dict[Any, List[TimerRequest]] = {}
        for request in self._timers.values():
            if request.expiration_time <= deadline:
                expired_scopes = expired_timers.setdefault(request.worker_id, [])
                expired_scopes.append(request)
        return expired_timers

    def _reap_worker(self, worker_id: int) -> bool:
        """Send ``SIGKILL`` to the worker process ``worker_id``.

        Returns:
            ``True`` when there is nothing left to do for this worker: it was
            signalled, it no longer exists, or its id is a non-positive pid
            that must never be signalled. ``False`` when signalling failed for
            any other reason (the error is logged).
        """
        # os.kill() does not address a single process for pid <= 0: 0 means
        # "every process in the caller's process group" (i.e. this server),
        # -1 means "every process the caller may signal", and other negative
        # values address process groups. A malformed request must not turn a
        # timeout into a mass SIGKILL. As with a missing process, report the
        # worker as handled so the timer is not retried forever.
        if isinstance(worker_id, int) and worker_id <= 0:
            logger.error(
                "Refusing to send SIGKILL to non-positive pid=%s. Skipping", worker_id
            )
            return True
        try:
            os.kill(worker_id, signal.SIGKILL)
            return True
        except ProcessLookupError:
            logger.info("Process with pid=%s does not exist. Skipping", worker_id)
            return True
        except Exception:
            logger.exception("Error terminating pid=%s", worker_id)
        return False
128        return False
129