# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import multiprocessing as mp
import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set, Tuple

from .api import RequestQueue, TimerClient, TimerRequest, TimerServer


__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"]

logger = logging.getLogger(__name__)


class LocalTimerClient(TimerClient):
    """
    Client side of ``LocalTimerServer``. This client is meant to be used
    on the same host that the ``LocalTimerServer`` is running on and uses
    pid to uniquely identify a worker. This is particularly useful in situations
    where one spawns a subprocess (trainer) per GPU on a host with multiple
    GPU devices.
    """

    def __init__(self, mp_queue):
        """
        Args:
            mp_queue: queue onto which timer requests are put; presumably the
                same ``multiprocessing.Queue`` handed to the server.
        """
        super().__init__()
        self._mp_queue = mp_queue

    def acquire(self, scope_id, expiration_time):
        """
        Enqueue a timer request for the calling process.

        The worker id of the request is the pid returned by ``os.getpid()``.

        Args:
            scope_id: identifier of the timed scope.
            expiration_time: expiration carried by the request; the server
                compares it against a deadline in ``get_expired_timers``.
        """
        pid = os.getpid()
        acquire_request = TimerRequest(pid, scope_id, expiration_time)
        self._mp_queue.put(acquire_request)

    def release(self, scope_id):
        """
        Enqueue a release request for ``scope_id`` for the calling process.

        A release is encoded as a request whose expiration time is ``-1``;
        ``LocalTimerServer.register_timers`` treats any negative expiration
        as a release.
        """
        pid = os.getpid()
        release_request = TimerRequest(pid, scope_id, -1)
        self._mp_queue.put(release_request)


class MultiprocessingRequestQueue(RequestQueue):
    """
    A ``RequestQueue`` backed by python ``multiprocessing.Queue``
    """

    def __init__(self, mp_queue: mp.Queue):
        """
        Args:
            mp_queue: the underlying queue that requests are read from.
        """
        super().__init__()
        self._mp_queue = mp_queue

    def size(self) -> int:
        """
        Return ``qsize()`` of the underlying queue.

        NOTE: per the ``multiprocessing`` documentation ``qsize()`` is
        approximate and may raise ``NotImplementedError`` on some platforms.
        """
        return self._mp_queue.qsize()

    def get(self, size, timeout: float) -> List[TimerRequest]:
        """
        Fetch up to ``size`` requests, spending at most roughly ``timeout``
        seconds in total across all blocking reads.

        Args:
            size: maximum number of requests to return.
            timeout: overall time budget, in seconds, shared by all reads.

        Returns:
            The requests read so far; possibly empty, and shorter than
            ``size`` if the queue ran dry or the budget was exhausted.
        """
        requests = []
        wait = timeout
        for _ in range(size):
            # Monotonic clock: elapsed-time accounting must not be affected
            # by wall-clock adjustments while we are blocked on the queue.
            start = time.monotonic()

            try:
                r = self._mp_queue.get(block=True, timeout=wait)
            except Empty:
                # Nothing arrived within the remaining budget.
                break

            requests.append(r)
            # Charge the time spent on this read against the shared budget.
            wait -= time.monotonic() - start
            if wait <= 0:
                break

        return requests


class LocalTimerServer(TimerServer):
    """
    Server that works with ``LocalTimerClient``. Clients are expected to be
    subprocesses to the parent process that is running this server. Each host
    in the job is expected to start its own timer server locally and each
    server instance manages timers for local workers (running on processes
    on the same host).
    """

    def __init__(
        self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
    ):
        """
        Args:
            mp_queue: queue that clients put their requests on; it is wrapped
                in a ``MultiprocessingRequestQueue`` for the base class.
            max_interval: forwarded unchanged to ``TimerServer``.
            daemon: forwarded unchanged to ``TimerServer``.
        """
        super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
        # (pid, scope_id) -> most recent acquire request for that scope
        self._timers: Dict[Tuple[Any, str], TimerRequest] = {}

    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        """
        Apply a batch of requests to the timer table, in order.

        A non-negative expiration stores (or overwrites) the timer for
        ``(worker_id, scope_id)``; a negative expiration removes it if present.
        """
        for request in timer_requests:
            key = (request.worker_id, request.scope_id)

            # negative expiration is a proxy for a release call
            if request.expiration_time < 0:
                self._timers.pop(key, None)
            else:
                self._timers[key] = request

    def clear_timers(self, worker_ids: Set[int]) -> None:
        """
        Drop every timer whose pid is in ``worker_ids``.
        """
        # Snapshot the matching keys first so the dict is not mutated while
        # it is being iterated.
        stale_keys = [key for key in self._timers if key[0] in worker_ids]
        for key in stale_keys:
            del self._timers[key]

    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        """
        Group the timers whose ``expiration_time`` is ``<= deadline`` by worker.

        Returns:
            Mapping of worker id (pid) to the list of its expired requests;
            workers without expired timers are absent from the mapping.
        """
        # pid -> [timer_requests...]
        expired_timers: Dict[Any, List[TimerRequest]] = {}
        for request in self._timers.values():
            if request.expiration_time <= deadline:
                expired_timers.setdefault(request.worker_id, []).append(request)
        return expired_timers

    def _reap_worker(self, worker_id: int) -> bool:
        """
        Send ``SIGKILL`` to the process with pid ``worker_id``.

        Returns:
            ``True`` if the signal was sent or the process no longer exists;
            ``False`` if any other exception occurred (it is logged, not
            raised).
        """
        try:
            os.kill(worker_id, signal.SIGKILL)
            return True
        except ProcessLookupError:
            # Already gone: nothing left to reap, so report success.
            logger.info("Process with pid=%s does not exist. Skipping", worker_id)
            return True
        except Exception:
            logger.exception("Error terminating pid=%s", worker_id)
        return False