# mypy: allow-untyped-defs
import functools
import itertools
import logging
import multiprocessing
import os
import pickle
import struct
import subprocess
import sys
import threading
import typing
from concurrent.futures import Future, ProcessPoolExecutor
from typing import Any, Callable, Dict

from torch._inductor.compile_worker.watchdog import _async_compile_initializer

log = logging.getLogger(__name__)


def _pack_msg(job_id, length):
    # Fixed-size message header: (job_id, payload_length) packed as two
    # native ssize_t values.  Native mode is fine here because both ends
    # of the pipe run on the same machine.
    return struct.pack("nn", job_id, length)


def _unpack_msg(data):
    if not data:
        # Empty read: the other end of the pipe was closed (EOF).
        return -1, -1
    return struct.unpack("nn", data)


# Size of the fixed message header in bytes.
msg_bytes = len(_pack_msg(0, 0))


def _send_msg(write_pipe, job_id, job_data=b""):
    length = len(job_data)
    write_pipe.write(_pack_msg(job_id, length))
    if length > 0:
        write_pipe.write(job_data)
    write_pipe.flush()


def _recv_msg(read_pipe):
    job_id, length = _unpack_msg(read_pipe.read(msg_bytes))
    data = read_pipe.read(length) if length > 0 else b""
    return job_id, data
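

# Illustration of the wire format above (exposition only; nothing in this
# module executes this).  A message is the fixed-size header followed by an
# optional payload, and the header round-trips through the helpers:
#
#     header = _pack_msg(3, 5)               # two native ssize_t values
#     assert _unpack_msg(header) == (3, 5)
#     message = header + b"hello"            # job_id=3 with a 5-byte payload
#
# A job_id of -1 is reserved as the shutdown sentinel; _unpack_msg also
# returns (-1, -1) when the pipe hits EOF.

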
class SubprocPool:
    """
    Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in
    a subprocess.Popen() to try to avoid issues with forking/spawning.
    """

    def __init__(self, nprocs: int):
        entry = os.path.join(os.path.dirname(__file__), "__main__.py")

        # Two one-way pipes wire the parent to the subprocess:
        #   parent write_pipe -> subproc_read_fd  (job submissions)
        #   subproc_write_fd  -> parent read_pipe (results)
        subproc_read_fd, write_fd = os.pipe()
        read_fd, subproc_write_fd = os.pipe()
        self.write_pipe = os.fdopen(write_fd, "wb")
        self.read_pipe = os.fdopen(read_fd, "rb")

        cmd = [
            sys.executable,
            entry,
            f"--workers={nprocs}",
            f"--parent={os.getpid()}",
            f"--read-fd={subproc_read_fd}",
            f"--write-fd={subproc_write_fd}",
        ]
        self.process = subprocess.Popen(
            cmd,
            env={
                **os.environ,
                # We need to set the PYTHONPATH so the subprocess can find torch.
                "PYTHONPATH": os.pathsep.join(sys.path),
                # We don't want to re-warm the pool when the subprocess imports
                # torch._inductor.codecache, since the warming process is what
                # creates the SubprocPool in the first place.
                "TORCH_WARM_POOL": "0",
            },
            pass_fds=(subproc_read_fd, subproc_write_fd),
        )
        self.write_lock = threading.Lock()
        self.read_thread = threading.Thread(target=self._read_thread, daemon=True)

        self.futures_lock = threading.Lock()
        self.pending_futures: Dict[int, Future[Any]] = {}
        self.job_id_count = itertools.count()

        self.running = True

        # Start the reader thread last to ensure all member variables are
        # initialized before any access.
        self.read_thread.start()
    def submit(self, job_fn: Callable[..., Any], *args):
        if args:
            job_fn = functools.partial(job_fn, *args)
        job_data = pickle.dumps(job_fn, pickle.HIGHEST_PROTOCOL)
        future: Future[Any]
        with self.futures_lock:
            job_id = next(self.job_id_count)
            self.pending_futures[job_id] = future = Future()
        future.set_running_or_notify_cancel()
        with self.write_lock:
            if not self.running:
                raise RuntimeError("submit() on closed pool")
            _send_msg(self.write_pipe, job_id, job_data)
        return future

    def _read_thread(self):
        try:
            while True:
                job_id, data = _recv_msg(self.read_pipe)
                if job_id < 0:
                    # Shutdown sentinel from the subprocess, or EOF.
                    if self.running:
                        log.warning("SubprocPool unclean exit")
                    self.read_pipe.close()
                    return
                result = pickle.loads(data)
                with self.futures_lock:
                    if not self.running:
                        return
                    if isinstance(result, Exception):
                        self.pending_futures[job_id].set_exception(result)
                    else:
                        self.pending_futures[job_id].set_result(result)
                    del self.pending_futures[job_id]
        except Exception:
            log.exception("failure in SubprocPool._read_thread")

    def shutdown(self):
        try:
            with self.write_lock:
                if not self.running:
                    return
                self.running = False
                # Tell the subprocess to exit.
                _send_msg(self.write_pipe, -1)
                self.write_pipe.close()
            self.process.wait(10)
        except OSError as e:
            log.warning("Ignored OSError in pool shutdown: %s", e)
        finally:
            with self.futures_lock:
                for future in self.pending_futures.values():
                    if not future.cancel():
                        future.set_exception(RuntimeError("SubprocPool closed"))
                self.pending_futures.clear()
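

# Sketch of how a caller uses SubprocPool (illustrative; inside PyTorch the
# pool is created by the inductor async-compile machinery rather than by
# hand):
#
#     import operator
#     pool = SubprocPool(nprocs=4)
#     fut = pool.submit(operator.add, 1, 2)  # pickled and run in a worker
#     assert fut.result() == 3               # blocks until the result arrives
#     pool.shutdown()

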
class SubprocMain:
    """Communicates with a SubprocPool in the parent process; called by __main__.py"""

    def __init__(self, nprocs, read_pipe, write_pipe):
        self.read_pipe = read_pipe
        self.write_pipe = write_pipe
        self.write_lock = threading.Lock()
        self.pool = ProcessPoolExecutor(
            nprocs,
            mp_context=multiprocessing.get_context("fork"),
            initializer=functools.partial(_async_compile_initializer, os.getpid()),
        )
        multiprocessing.util.Finalize(
            None, self.pool.shutdown, exitpriority=sys.maxsize
        )
        self.running = True
        _warm_process_pool(self.pool, nprocs)

    def main(self):
        while True:
            job_id, data = _recv_msg(self.read_pipe)
            if job_id < 0:
                return self._shutdown()
            self.submit(job_id, data)

    def _shutdown(self):
        with self.write_lock:
            self.running = False
            try:
                _send_msg(self.write_pipe, -1)
                self.write_pipe.close()
            except BrokenPipeError:
                pass  # parent process already shut down
            self.read_pipe.close()
        self.pool.shutdown()

    def submit(self, job_id, data):
        future = self.pool.submit(functools.partial(SubprocMain.do_job, data))

        def callback(_):
            if not self.running:
                return
            try:
                result = future.result()
            except Exception as e:
                log.exception("Error in subprocess")
                result = pickle.dumps(e, pickle.HIGHEST_PROTOCOL)
            assert isinstance(result, bytes)
            with self.write_lock:
                if self.running:
                    _send_msg(self.write_pipe, job_id, result)

        future.add_done_callback(callback)

    @staticmethod
    def do_job(data):
        # Do the pickle/unpickle in the sub-subprocess (the worker),
        # not in SubprocMain.
        job = pickle.loads(data)
        result = job()
        return pickle.dumps(result, pickle.HIGHEST_PROTOCOL)
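

# Rough sketch of the child side of the handshake (the real entry point is
# compile_worker/__main__.py; this approximation is inferred from the flags
# built in SubprocPool.__init__ above):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--workers", type=int)
#     parser.add_argument("--parent", type=int)
#     parser.add_argument("--read-fd", type=int)
#     parser.add_argument("--write-fd", type=int)
#     args = parser.parse_args()
#     SubprocMain(
#         args.workers,
#         os.fdopen(args.read_fd, "rb"),
#         os.fdopen(args.write_fd, "wb"),
#     ).main()

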
AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool]


def _warm_process_pool(pool: AnyPool, n: int):
    if isinstance(pool, SubprocPool):
        return  # no need
    assert isinstance(pool, ProcessPoolExecutor)

    # We have to fork processes for compiler workers, but the more memory and
    # other resources the process has loaded, the more drastically os.fork
    # slows down.  Forking also holds the GIL, so we can't move it to another
    # thread.

    # Examples:
    # A simple x + x + x script: 10ms in the middle of the program, 2ms at startup.
    # The tf_efficientnet_b0 benchmark: 50ms(!) in the middle of the program,
    # 3ms at startup.

    # So we want to start the workers early, while forking is still cheap, and
    # also to give the workers time to get ready before we have work for them.

    # ProcessPoolExecutor also does not launch its workers until it finds a
    # point when all of them are idle.  But if we waited until then, the fork
    # would be slow and we would be waiting for the processes to initialize.

    # We force them to start here with some YOLOing of the internal methods.

    # TODO(masnesral): Are these still relevant?
    if hasattr(pool, "_start_queue_management_thread"):
        pool._start_queue_management_thread()
    else:
        for _ in range(n):
            pool._adjust_process_count()
        if hasattr(pool, "_start_executor_manager_thread"):
            pool._start_executor_manager_thread()


class TestException(RuntimeError):
    pass


def raise_testexc():
    raise TestException
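

# TestException and raise_testexc appear to exist so tests can exercise the
# error path: the worker raises, SubprocMain's callback pickles the exception
# back over the pipe, and _read_thread re-raises it from the future:
#
#     fut = pool.submit(raise_testexc)
#     try:
#         fut.result()
#     except TestException:
#         pass  # the worker's exception surfaces here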