1# mypy: allow-untyped-defs 2import functools 3import itertools 4import logging 5import multiprocessing 6import os 7import pickle 8import struct 9import subprocess 10import sys 11import threading 12import typing 13from concurrent.futures import Future, ProcessPoolExecutor 14from typing import Any, Callable, Dict 15 16from torch._inductor.compile_worker.watchdog import _async_compile_initializer 17 18log = logging.getLogger(__name__) 19 20 21def _pack_msg(job_id, length): 22 return struct.pack("nn", job_id, length) 23 24 25def _unpack_msg(data): 26 if not data: 27 return -1, -1 28 return struct.unpack("nn", data) 29 30 31msg_bytes = len(_pack_msg(0, 0)) 32 33 34def _send_msg(write_pipe, job_id, job_data=b""): 35 length = len(job_data) 36 write_pipe.write(_pack_msg(job_id, length)) 37 if length > 0: 38 write_pipe.write(job_data) 39 write_pipe.flush() 40 41 42def _recv_msg(read_pipe): 43 job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) 44 data = read_pipe.read(length) if length > 0 else b"" 45 return job_id, data 46 47 48class SubprocPool: 49 """ 50 Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in 51 a subprocess.Popen() to try to avoid issues with forking/spawning 52 """ 53 54 def __init__(self, nprocs: int): 55 entry = os.path.join(os.path.dirname(__file__), "__main__.py") 56 57 subproc_read_fd, write_fd = os.pipe() 58 read_fd, subproc_write_fd = os.pipe() 59 self.write_pipe = os.fdopen(write_fd, "wb") 60 self.read_pipe = os.fdopen(read_fd, "rb") 61 62 cmd = [ 63 sys.executable, 64 entry, 65 f"--workers={nprocs}", 66 f"--parent={os.getpid()}", 67 f"--read-fd={str(subproc_read_fd)}", 68 f"--write-fd={str(subproc_write_fd)}", 69 ] 70 self.process = subprocess.Popen( 71 cmd, 72 env={ 73 **os.environ, 74 # We need to set the PYTHONPATH so the subprocess can find torch. 75 "PYTHONPATH": os.pathsep.join(sys.path), 76 # We don't want to re-warm the pool when the subprocess imports 77 # torch._inductor.codecache since the warming process is what 78 # creates the SubprocPool in the first place. 79 "TORCH_WARM_POOL": "0", 80 }, 81 pass_fds=(subproc_read_fd, subproc_write_fd), 82 ) 83 self.write_lock = threading.Lock() 84 self.read_thread = threading.Thread(target=self._read_thread, daemon=True) 85 86 self.futures_lock = threading.Lock() 87 self.pending_futures: Dict[int, Future[Any]] = {} 88 self.job_id_count = itertools.count() 89 90 self.running = True 91 92 # Start thread last to ensure all member variables are initialized 93 # before any access. 94 self.read_thread.start() 95 96 def submit(self, job_fn: Callable[..., Any], *args): 97 if args: 98 job_fn = functools.partial(job_fn, *args) 99 job_data = pickle.dumps(job_fn, pickle.HIGHEST_PROTOCOL) 100 future: Future[Any] 101 with self.futures_lock: 102 job_id = next(self.job_id_count) 103 self.pending_futures[job_id] = future = Future() 104 future.set_running_or_notify_cancel() 105 with self.write_lock: 106 if not self.running: 107 raise RuntimeError("submit() on closed pool") 108 _send_msg(self.write_pipe, job_id, job_data) 109 return future 110 111 def _read_thread(self): 112 try: 113 while True: 114 job_id, data = _recv_msg(self.read_pipe) 115 if job_id < 0: 116 if self.running: 117 log.warning("SubprocPool unclean exit") 118 self.read_pipe.close() 119 return 120 result = pickle.loads(data) 121 with self.futures_lock: 122 if not self.running: 123 return 124 if isinstance(result, Exception): 125 self.pending_futures[job_id].set_exception(result) 126 else: 127 self.pending_futures[job_id].set_result(result) 128 del self.pending_futures[job_id] 129 except Exception: 130 log.exception("failure in SubprocPool._read_thread") 131 132 def shutdown(self): 133 try: 134 with self.write_lock: 135 if not self.running: 136 return 137 self.running = False 138 _send_msg(self.write_pipe, -1) 139 self.write_pipe.close() 140 self.process.wait(10) 141 except OSError as e: 142 log.warning("Ignored OSError in pool shutdown: %s", e) 143 finally: 144 with self.futures_lock: 145 for future in self.pending_futures.values(): 146 if not future.cancel(): 147 future.set_exception(RuntimeError("SubprocPool closed")) 148 self.pending_futures.clear() 149 150 151class SubprocMain: 152 """Communicates with a SubprocPool in the parent process, called by __main__.py""" 153 154 def __init__(self, nprocs, read_pipe, write_pipe): 155 self.read_pipe = read_pipe 156 self.write_pipe = write_pipe 157 self.write_lock = threading.Lock() 158 self.pool = ProcessPoolExecutor( 159 nprocs, 160 mp_context=multiprocessing.get_context("fork"), 161 initializer=functools.partial(_async_compile_initializer, os.getpid()), 162 ) 163 multiprocessing.util.Finalize( 164 None, self.pool.shutdown, exitpriority=sys.maxsize 165 ) 166 self.running = True 167 _warm_process_pool(self.pool, nprocs) 168 169 def main(self): 170 while True: 171 job_id, data = _recv_msg(self.read_pipe) 172 if job_id < 0: 173 return self._shutdown() 174 self.submit(job_id, data) 175 176 def _shutdown(self): 177 with self.write_lock: 178 self.running = False 179 try: 180 _send_msg(self.write_pipe, -1) 181 self.write_pipe.close() 182 except BrokenPipeError: 183 pass # parent process already shutdown 184 self.read_pipe.close() 185 self.pool.shutdown() 186 187 def submit(self, job_id, data): 188 future = self.pool.submit(functools.partial(SubprocMain.do_job, data)) 189 190 def callback(_): 191 if not self.running: 192 return 193 try: 194 result = future.result() 195 except Exception as e: 196 log.exception("Error in subprocess") 197 result = pickle.dumps(e, pickle.HIGHEST_PROTOCOL) 198 assert isinstance(result, bytes) 199 with self.write_lock: 200 if self.running: 201 _send_msg(self.write_pipe, job_id, result) 202 203 future.add_done_callback(callback) 204 205 @staticmethod 206 def do_job(data): 207 # do the pickle/unpickle in the sub-subproc 208 job = pickle.loads(data) 209 result = job() 210 return pickle.dumps(result, pickle.HIGHEST_PROTOCOL) 211 212 213AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool] 214 215 216def _warm_process_pool(pool: AnyPool, n: int): 217 if isinstance(pool, SubprocPool): 218 return # no need 219 assert isinstance(pool, ProcessPoolExecutor) 220 221 # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the 222 # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread. 223 224 # Examples: 225 # A simple x + x + x script: 10ms seconds in the middle of the program, 2ms at startup 226 # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program , 3ms at startup 227 228 # So we want to start the workers early when it is still cheap, and also to allow the workers to get 229 # ready before we have work for them. 230 231 # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle. 232 # But if we waited until then fork time will be long and we will be waiting for the processes to initialize. 233 234 # We force them to start here with some YOLOing of the internal methods. 235 236 # TODO(masnesral): Are these still relevant? 237 if hasattr(pool, "_start_queue_management_thread"): 238 pool._start_queue_management_thread() 239 else: 240 for _ in range(n): 241 pool._adjust_process_count() 242 if hasattr(pool, "_start_executor_manager_thread"): 243 pool._start_executor_manager_thread() 244 245 246class TestException(RuntimeError): 247 pass 248 249 250def raise_testexc(): 251 raise TestException 252