# xref: /aosp_15_r20/external/pytorch/torch/multiprocessing/pool.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import inspect
import multiprocessing.pool
import multiprocessing.util as util

from .queue import SimpleQueue


def clean_worker(*args, **kwargs):
    """Run the standard pool worker loop, then force a garbage collection.

    All positional and keyword arguments are forwarded unchanged to
    ``multiprocessing.pool.worker``.  Once that loop returns, a full
    collection is triggered so that pending destructors run before the
    worker process goes away.
    """
    import gc

    multiprocessing.pool.worker(*args, **kwargs)
    # Regular multiprocessing workers don't fully clean up after themselves,
    # so we have to explicitly trigger garbage collection to make sure that all
    # destructors are called...
    gc.collect()


class Pool(multiprocessing.pool.Pool):
    """Pool implementation which uses our version of SimpleQueue.

    This lets us pass tensors in shared memory across processes instead of
    serializing the underlying data.
    """

    def _setup_queues(self):
        """Create the inbound/outbound queues and cache their raw endpoints."""
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        # Shortcuts straight to the underlying connection objects: tasks are
        # written to the inbound queue, results are read from the outbound one.
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def _repopulate_pool(self):
        """Increase the number of pool processes to the specified number.

        Bring the number of pool processes up to the specified number, for use after
        reaping workers which have exited.
        """
        missing = self._processes - len(self._pool)
        if missing <= 0:
            return

        # changed worker -> clean_worker
        # These are the arguments clean_worker forwards to
        # multiprocessing.pool.worker; they do not vary between workers.
        args = (
            self._inqueue,
            self._outqueue,
            self._initializer,
            self._initargs,
            self._maxtasksperchild,
        )
        # Only forward the flag when the base class provides it.
        if hasattr(self, "_wrap_exception"):
            args += (self._wrap_exception,)

        # Depending on the interpreter version, the base class exposes
        # ``Process`` either as a regular method or as a staticmethod whose
        # first argument is the multiprocessing context.  In the latter case
        # the context has to be passed explicitly, otherwise the call fails
        # with a missing-argument TypeError.
        takes_ctx = isinstance(
            inspect.getattr_static(type(self), "Process"), staticmethod
        )

        for _ in range(missing):
            if takes_ctx:
                w = self.Process(self._ctx, target=clean_worker, args=args)
            else:
                w = self.Process(target=clean_worker, args=args)
            self._pool.append(w)
            w.name = w.name.replace("Process", "PoolWorker")
            w.daemon = True
            w.start()
            util.debug("added worker")