xref: /aosp_15_r20/external/pytorch/torch/multiprocessing/queue.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import io
3import multiprocessing.queues
4import pickle
5from multiprocessing.reduction import ForkingPickler
6
7
class ConnectionWrapper:
    """Wrap a ``multiprocessing`` connection so objects are serialized with ``ForkingPickler``.

    ``send`` pickles with ``ForkingPickler`` at ``pickle.HIGHEST_PROTOCOL`` and
    ``recv`` unpickles with ``pickle.loads``.  Every other attribute lookup that
    fails on the wrapper itself is forwarded to the wrapped connection.
    """

    def __init__(self, conn):
        # The wrapped connection; unknown attributes are resolved on it.
        self.conn = conn

    def send(self, obj):
        """Pickle *obj* with ``ForkingPickler`` and send the bytes via ``send_bytes``."""
        with io.BytesIO() as stream:
            pickler = ForkingPickler(stream, pickle.HIGHEST_PROTOCOL)
            pickler.dump(obj)
            payload = stream.getvalue()
        # ``send_bytes`` is not defined here; it resolves to the wrapped connection.
        self.send_bytes(payload)

    def recv(self):
        """Receive one message via ``recv_bytes`` and return the unpickled object."""
        return pickle.loads(self.recv_bytes())

    def __getattr__(self, name):
        # Only called when normal lookup fails.  If ``conn`` is not yet in the
        # instance dict (e.g. an instance created via ``__new__`` without
        # ``__init__``), evaluating ``self.conn`` would re-enter
        # ``__getattr__`` forever, so fail fast with AttributeError instead.
        if "conn" not in self.__dict__:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'")
        return getattr(self.conn, name)
27
28
class Queue(multiprocessing.queues.Queue):
    """``multiprocessing.queues.Queue`` whose pipe ends are wrapped in ``ConnectionWrapper``.

    All construction arguments are forwarded unchanged to the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the base class's reader/writer with ForkingPickler-based wrappers.
        wrapped_reader = ConnectionWrapper(self._reader)
        wrapped_writer = ConnectionWrapper(self._writer)
        self._reader: ConnectionWrapper = wrapped_reader
        self._writer: ConnectionWrapper = wrapped_writer
        # Object-level send/recv helpers bound to the wrappers.
        # NOTE(review): confirm whether the stdlib base class still reads
        # ``_send``/``_recv``; they are kept so existing users keep working.
        self._recv = wrapped_reader.recv
        self._send = wrapped_writer.send
36
37
class SimpleQueue(multiprocessing.queues.SimpleQueue):
    """Simple queue whose ``_make_methods`` hook wraps the pipe ends in ``ConnectionWrapper``."""

    def _make_methods(self):
        """Wrap the reader/writer once, then run the base-class hook if it exists.

        Some stdlib versions of ``multiprocessing.queues.SimpleQueue`` may not
        define ``_make_methods``.  Calling ``super()._make_methods()``
        unconditionally would then raise ``AttributeError`` after the pipe
        ends were already wrapped, so the parent hook is looked up
        defensively and only invoked when present.
        """
        # The isinstance check prevents double wrapping if the hook runs again.
        if not isinstance(self._reader, ConnectionWrapper):
            self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
            self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
        parent_make_methods = getattr(super(), "_make_methods", None)
        if parent_make_methods is not None:
            parent_make_methods()
44