xref: /aosp_15_r20/external/pytorch/torch/multiprocessing/pool.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import multiprocessing.pool
2import multiprocessing.util as util
3
4from .queue import SimpleQueue
5
6
def clean_worker(*args, **kwargs):
    """Run the stock pool worker loop, then force a garbage-collection pass.

    Every positional and keyword argument is forwarded unchanged to
    ``multiprocessing.pool.worker``; its return value is discarded and this
    function returns ``None``.
    """
    from gc import collect as _collect_garbage

    run_stock_worker = multiprocessing.pool.worker
    run_stock_worker(*args, **kwargs)
    # The stock multiprocessing worker does not fully clean up after itself,
    # so an explicit collection is run to make sure pending destructors fire.
    _collect_garbage()
15
16
class Pool(multiprocessing.pool.Pool):
    """Pool implementation which uses our version of SimpleQueue.

    This lets us pass tensors in shared memory across processes instead of
    serializing the underlying data.

    Workers started by ``_repopulate_pool`` run ``clean_worker`` instead of the
    stock ``multiprocessing.pool.worker``.
    """

    def _setup_queues(self):
        """Create the pool's task and result queues using our ``SimpleQueue``."""
        # The stdlib ``SimpleQueue`` constructor takes a required keyword-only
        # ``ctx`` argument; the base ``Pool.__init__`` assigns ``self._ctx``
        # before invoking this hook. NOTE(review): assumes ``.queue.SimpleQueue``
        # keeps the stdlib constructor signature — confirm.
        self._inqueue = SimpleQueue(ctx=self._ctx)
        self._outqueue = SimpleQueue(ctx=self._ctx)
        # Raw connection endpoints, exposed under the attribute names the base
        # ``multiprocessing.pool.Pool`` reads.
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def _repopulate_pool(self):
        """Increase the number of pool processes to the specified number.

        Bring the number of pool processes up to the specified number, for use after
        reaping workers which have exited. Each new worker runs ``clean_worker``
        and is started as a daemon process.
        """
        for _ in range(self._processes - len(self._pool)):
            # changed worker -> clean_worker
            args = (
                self._inqueue,
                self._outqueue,
                self._initializer,
                self._initargs,
                self._maxtasksperchild,
            )
            # Newer stdlib worker signatures accept a trailing wrap_exception flag.
            if hasattr(self, "_wrap_exception"):
                args += (self._wrap_exception,)
            if hasattr(self, "_repopulate_pool_static"):
                # Python >= 3.8: the base ``Pool.Process`` is a staticmethod that
                # takes the multiprocessing context as its first argument.
                w = self.Process(self._ctx, target=clean_worker, args=args)
            else:
                # Older Pythons: ``Pool.Process`` is a bound method that reads
                # ``self._ctx`` on its own.
                w = self.Process(target=clean_worker, args=args)
            w.name = w.name.replace("Process", "PoolWorker")
            w.daemon = True
            w.start()
            # Track the worker only once it has started, so cleanup code that
            # terminates/joins everything in ``self._pool`` never touches an
            # unstarted process.
            # NOTE(review): on Python >= 3.8 the stdlib worker-handler thread
            # may respawn exited workers through the base class, which would
            # use the stock worker rather than clean_worker — confirm if needed.
            self._pool.append(w)
            util.debug("added worker")
53