xref: /aosp_15_r20/prebuilts/build-tools/common/py3-stdlib/asyncio/streams.py (revision cda5da8d549138a6648c5ee6d7a49cf8f4a657be)
# Public names exported by this module; extended below with the
# UNIX-domain-socket helpers when the platform supports AF_UNIX.
__all__ = (
    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
    'open_connection', 'start_server')
4
5import collections
6import socket
7import sys
8import warnings
9import weakref
10
if hasattr(socket, 'AF_UNIX'):
    # UNIX domain sockets exist on this platform, so the corresponding
    # stream helpers (defined further below) are part of the public API.
    __all__ += ('open_unix_connection', 'start_unix_server')
13
14from . import coroutines
15from . import events
16from . import exceptions
17from . import format_helpers
18from . import protocols
19from .log import logger
20from .tasks import sleep
21
22
23_DEFAULT_LIMIT = 2 ** 16  # 64 KiB
24
25
async def open_connection(host=None, port=None, *,
                          limit=_DEFAULT_LIMIT, **kwds):
    """A wrapper for create_connection() returning a (reader, writer) pair.

    The reader returned is a StreamReader instance; the writer is a
    StreamWriter instance.

    All arguments other than *limit* are forwarded unchanged to
    loop.create_connection(); positional *host* and *port* are the most
    common.  *limit* sets the buffer limit passed to the StreamReader.

    (If you want to customize the StreamReader and/or
    StreamReaderProtocol classes, just copy the code -- there's
    really nothing special here except some convenience.)
    """
    loop = events.get_running_loop()
    stream_reader = StreamReader(limit=limit, loop=loop)
    reader_protocol = StreamReaderProtocol(stream_reader, loop=loop)
    transport, _ = await loop.create_connection(
        lambda: reader_protocol, host, port, **kwds)
    stream_writer = StreamWriter(transport, reader_protocol,
                                 stream_reader, loop)
    return stream_reader, stream_writer
52
53
async def start_server(client_connected_cb, host=None, port=None, *,
                       limit=_DEFAULT_LIMIT, **kwds):
    """Start a socket server, call back for each client connected.

    The first parameter, `client_connected_cb`, takes two parameters:
    client_reader, client_writer.  client_reader is a StreamReader
    object, while client_writer is a StreamWriter object.  This
    parameter can either be a plain callback function or a coroutine;
    if it is a coroutine, it will be automatically converted into a
    Task.

    All arguments other than *limit* are forwarded unchanged to
    loop.create_server(); positional *host* and *port* are the most
    common.  *limit* sets the buffer limit passed to each StreamReader.

    The return value is the same as loop.create_server(), i.e. a
    Server object which can be used to stop the service.
    """
    loop = events.get_running_loop()

    def protocol_factory():
        # A fresh reader/protocol pair is created per incoming connection.
        stream_reader = StreamReader(limit=limit, loop=loop)
        return StreamReaderProtocol(stream_reader, client_connected_cb,
                                    loop=loop)

    return await loop.create_server(protocol_factory, host, port, **kwds)
86
87
if hasattr(socket, 'AF_UNIX'):
    # UNIX Domain Sockets are supported on this platform.

    async def open_unix_connection(path=None, *,
                                   limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `open_connection` but works with UNIX Domain Sockets."""
        loop = events.get_running_loop()
        stream_reader = StreamReader(limit=limit, loop=loop)
        reader_protocol = StreamReaderProtocol(stream_reader, loop=loop)
        transport, _ = await loop.create_unix_connection(
            lambda: reader_protocol, path, **kwds)
        stream_writer = StreamWriter(transport, reader_protocol,
                                     stream_reader, loop)
        return stream_reader, stream_writer

    async def start_unix_server(client_connected_cb, path=None, *,
                                limit=_DEFAULT_LIMIT, **kwds):
        """Similar to `start_server` but works with UNIX Domain Sockets."""
        loop = events.get_running_loop()

        def protocol_factory():
            # A fresh reader/protocol pair is created per incoming connection.
            stream_reader = StreamReader(limit=limit, loop=loop)
            return StreamReaderProtocol(stream_reader, client_connected_cb,
                                        loop=loop)

        return await loop.create_unix_server(protocol_factory, path, **kwds)
115
116
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_writing() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        # stacklevel=4 makes the "no running loop" deprecation warning
        # point at the user's call site rather than this helper.
        self._loop = (events._get_event_loop(stacklevel=4)
                      if loop is None else loop)
        self._paused = False
        # Futures of coroutines currently blocked in _drain_helper().
        self._drain_waiters = collections.deque()
        self._connection_lost = False

    def pause_writing(self):
        # Called by the transport when its write buffer passes the
        # high-water mark.
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        # Called by the transport when its write buffer drains; releases
        # every coroutine waiting in _drain_helper().
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        for fut in self._drain_waiters:
            if fut.done():
                continue
            fut.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer(s) if currently paused.
        if not self._paused:
            return

        for fut in self._drain_waiters:
            if fut.done():
                continue
            if exc is None:
                fut.set_result(None)
            else:
                fut.set_exception(exc)

    async def _drain_helper(self):
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        if not self._paused:
            # Writing is not throttled; nothing to wait for.
            return
        fut = self._loop.create_future()
        self._drain_waiters.append(fut)
        try:
            await fut
        finally:
            self._drain_waiters.remove(fut)

    def _get_close_waiter(self, stream):
        # Subclasses must supply the future awaited by wait_closed().
        raise NotImplementedError
179
180
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
    """Helper class to adapt between Protocol and StreamReader.

    (This is a helper class instead of making StreamReader itself a
    Protocol subclass, because the StreamReader has other potential
    uses, and to prevent the user of the StreamReader to accidentally
    call inappropriate methods of the protocol.)
    """

    # Creation-site traceback copied from the reader (only set when the
    # loop is in debug mode); attached to the exception-handler context
    # built in connection_made() below.
    _source_traceback = None

    def __init__(self, stream_reader, client_connected_cb=None, loop=None):
        super().__init__(loop=loop)
        # Hold the reader only weakly so an abandoned stream can be
        # garbage collected; _stream_reader (property below) dereferences it.
        if stream_reader is not None:
            self._stream_reader_wr = weakref.ref(stream_reader)
            self._source_traceback = stream_reader._source_traceback
        else:
            self._stream_reader_wr = None
        if client_connected_cb is not None:
            # This is a stream created by the `create_server()` function.
            # Keep a strong reference to the reader until a connection
            # is established.
            self._strong_reader = stream_reader
        self._reject_connection = False
        self._stream_writer = None
        self._task = None  # Task wrapping a coroutine client_connected_cb
        self._transport = None
        self._client_connected_cb = client_connected_cb
        self._over_ssl = False
        # Resolved in connection_lost(); returned by _get_close_waiter().
        self._closed = self._loop.create_future()

    @property
    def _stream_reader(self):
        # Dereference the weak reference; returns None once the reader
        # has been garbage collected or the connection was lost.
        if self._stream_reader_wr is None:
            return None
        return self._stream_reader_wr()

    def _replace_writer(self, writer):
        # Called by StreamWriter.start_tls() to swap in the TLS transport.
        loop = self._loop
        transport = writer.transport
        self._stream_writer = writer
        self._transport = transport
        self._over_ssl = transport.get_extra_info('sslcontext') is not None

    def connection_made(self, transport):
        if self._reject_connection:
            # The stream was garbage collected before the connection was
            # established; report the problem and abort the transport
            # rather than serve a dead stream.
            context = {
                'message': ('An open stream was garbage collected prior to '
                            'establishing network connection; '
                            'call "stream.close()" explicitly.')
            }
            if self._source_traceback:
                context['source_traceback'] = self._source_traceback
            self._loop.call_exception_handler(context)
            transport.abort()
            return
        self._transport = transport
        reader = self._stream_reader
        if reader is not None:
            reader.set_transport(transport)
        self._over_ssl = transport.get_extra_info('sslcontext') is not None
        if self._client_connected_cb is not None:
            # Server-side stream: build the writer and invoke the
            # user-supplied callback; a coroutine result is scheduled
            # as a Task.
            self._stream_writer = StreamWriter(transport, self,
                                               reader,
                                               self._loop)
            res = self._client_connected_cb(reader,
                                            self._stream_writer)
            if coroutines.iscoroutine(res):
                self._task = self._loop.create_task(res)
            # The connection is established; the weak reference
            # suffices from now on.
            self._strong_reader = None

    def connection_lost(self, exc):
        # Propagate close/error to the reader, resolve the close waiter,
        # then drop all references so reader/writer/task can be collected.
        reader = self._stream_reader
        if reader is not None:
            if exc is None:
                reader.feed_eof()
            else:
                reader.set_exception(exc)
        if not self._closed.done():
            if exc is None:
                self._closed.set_result(None)
            else:
                self._closed.set_exception(exc)
        super().connection_lost(exc)
        self._stream_reader_wr = None
        self._stream_writer = None
        self._task = None
        self._transport = None

    def data_received(self, data):
        # Forward incoming bytes to the reader's buffer (if still alive).
        reader = self._stream_reader
        if reader is not None:
            reader.feed_data(data)

    def eof_received(self):
        reader = self._stream_reader
        if reader is not None:
            reader.feed_eof()
        if self._over_ssl:
            # Prevent a warning in SSLProtocol.eof_received:
            # "returning true from eof_received()
            # has no effect when using ssl"
            return False
        # True keeps the transport half-open for writing after the
        # peer's EOF (standard Protocol.eof_received contract).
        return True

    def _get_close_waiter(self, stream):
        # Future awaited by StreamWriter.wait_closed().
        return self._closed

    def __del__(self):
        # Prevent reports about unhandled exceptions.
        # Better than self._closed._log_traceback = False hack
        try:
            closed = self._closed
        except AttributeError:
            pass  # failed constructor
        else:
            if closed.done() and not closed.cancelled():
                closed.exception()
299
300
class StreamWriter:
    """Wraps a Transport.

    This exposes write(), writelines(), [can_]write_eof(),
    get_extra_info() and close().  It adds drain() which returns an
    optional Future on which you can wait for flow control.  It also
    adds a transport property which references the Transport
    directly.
    """

    def __init__(self, transport, protocol, reader, loop):
        self._transport = transport
        self._protocol = protocol
        # drain() expects that the reader has an exception() method
        assert reader is None or isinstance(reader, StreamReader)
        self._reader = reader
        self._loop = loop
        # Pre-resolved future kept around as an already-completed result.
        self._complete_fut = self._loop.create_future()
        self._complete_fut.set_result(None)

    def __repr__(self):
        parts = [self.__class__.__name__, f'transport={self._transport!r}']
        if self._reader is not None:
            parts.append(f'reader={self._reader!r}')
        return '<{}>'.format(' '.join(parts))

    @property
    def transport(self):
        """The underlying Transport instance."""
        return self._transport

    def write(self, data):
        """Write *data* to the transport (does not block)."""
        self._transport.write(data)

    def writelines(self, data):
        """Write an iterable of byte chunks to the transport."""
        self._transport.writelines(data)

    def write_eof(self):
        """Close the write end of the transport after flushing."""
        return self._transport.write_eof()

    def can_write_eof(self):
        """Return True if the transport supports write_eof()."""
        return self._transport.can_write_eof()

    def close(self):
        """Close the transport."""
        return self._transport.close()

    def is_closing(self):
        """Return True if the transport is closed or closing."""
        return self._transport.is_closing()

    async def wait_closed(self):
        """Wait until the protocol reports the connection closed."""
        await self._protocol._get_close_waiter(self)

    def get_extra_info(self, name, default=None):
        """Return optional transport information (see Transport docs)."""
        return self._transport.get_extra_info(name, default)

    async def drain(self):
        """Flush the write buffer.

        The intended use is to write

          w.write(data)
          await w.drain()
        """
        if self._reader is not None:
            reader_exc = self._reader.exception()
            if reader_exc is not None:
                raise reader_exc
        if self._transport.is_closing():
            # Wait for protocol.connection_lost() call
            # Raise connection closing error if any,
            # ConnectionResetError otherwise
            # Yield to the event loop so connection_lost() may be
            # called.  Without this, _drain_helper() would return
            # immediately, and code that calls
            #     write(...); await drain()
            # in a loop would never call connection_lost(), so it
            # would not see an error when the socket is closed.
            await sleep(0)
        await self._protocol._drain_helper()

    async def start_tls(self, sslcontext, *,
                        server_hostname=None,
                        ssl_handshake_timeout=None):
        """Upgrade an existing stream-based connection to TLS."""
        proto = self._protocol
        # Server-side streams are the ones created with a
        # client_connected_cb (i.e. via start_server()).
        server_side = proto._client_connected_cb is not None
        await self.drain()
        new_transport = await self._loop.start_tls(  # type: ignore
            self._transport, proto, sslcontext,
            server_side=server_side, server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout)
        self._transport = new_transport
        proto._replace_writer(self)
393
394
class StreamReader:
    """Buffered byte-stream reader fed by a protocol.

    Bytes arrive via feed_data()/feed_eof() (called by
    StreamReaderProtocol) and are consumed with the read*() coroutines.
    The *limit* given at construction bounds readline()/readuntil()
    results; feed_data() pauses the transport once the buffer exceeds
    twice the limit, and the read*() methods resume it when drained.
    """

    # Creation-site traceback, captured in __init__ only when the loop
    # runs in debug mode; surfaced by StreamReaderProtocol diagnostics.
    _source_traceback = None

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.

        if limit <= 0:
            raise ValueError('Limit cannot be <= 0')

        self._limit = limit
        if loop is None:
            self._loop = events._get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None  # Set by set_exception(); raised by read*()
        self._transport = None  # Attached later via set_transport()
        self._paused = False    # True while transport reading is paused
        if self._loop.get_debug():
            self._source_traceback = format_helpers.extract_stack(
                sys._getframe(1))

    def __repr__(self):
        """Return a debug representation listing only non-default state."""
        info = ['StreamReader']
        if self._buffer:
            info.append(f'{len(self._buffer)} bytes')
        if self._eof:
            info.append('eof')
        if self._limit != _DEFAULT_LIMIT:
            info.append(f'limit={self._limit}')
        if self._waiter:
            info.append(f'waiter={self._waiter!r}')
        if self._exception:
            info.append(f'exception={self._exception!r}')
        if self._transport:
            info.append(f'transport={self._transport!r}')
        if self._paused:
            info.append('paused')
        return '<{}>'.format(' '.join(info))

    def exception(self):
        """Return the exception set with set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Store *exc* to be raised by read*() calls and wake the waiter."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read*() functions waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        """Attach the transport used for pause/resume flow control.

        May only be called once.
        """
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back to the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the stream as ended and wake up any read*() waiter."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append *data* to the internal buffer and wake up any waiter.

        Pauses the transport when the buffer grows past twice the limit.
        Must not be called after feed_eof().
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2 * self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    async def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called.

        If stream was paused, automatically resume it.
        """
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not be possible to
        # know which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                f'{func_name}() called while another coroutine is '
                f'already waiting for incoming data')

        assert not self._eof, '_wait_for_data after EOF'

        # Waiting for data while paused will make deadlock, so prevent it.
        # This is essential for readexactly(n) for case when n > self._limit.
        if self._paused:
            self._paused = False
            self._transport.resume_reading()

        self._waiter = self._loop.create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

    async def readline(self):
        """Read chunk of data from the stream until newline (b'\n') is found.

        On success, return chunk that ends with newline. If only partial
        line can be read due to EOF, return incomplete line without
        terminating newline. When EOF was reached while no bytes read, empty
        bytes object is returned.

        If limit is reached, ValueError will be raised. In that case, if
        newline was found, complete line including newline will be removed
        from internal buffer. Else, internal buffer will be cleared. Limit is
        compared against part of the line without newline.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        sep = b'\n'
        seplen = len(sep)
        try:
            line = await self.readuntil(sep)
        except exceptions.IncompleteReadError as e:
            # EOF hit before a newline: return whatever was buffered.
            return e.partial
        except exceptions.LimitOverrunError as e:
            # Over-limit line: discard it from the buffer, then report
            # the overrun as a ValueError (historical readline() API).
            if self._buffer.startswith(sep, e.consumed):
                del self._buffer[:e.consumed + seplen]
            else:
                self._buffer.clear()
            self._maybe_resume_transport()
            raise ValueError(e.args[0])
        return line

    async def readuntil(self, separator=b'\n'):
        """Read data from the stream until ``separator`` is found.

        On success, the data and separator will be removed from the
        internal buffer (consumed). Returned data will include the
        separator at the end.

        Configured stream limit is used to check result. Limit sets the
        maximal length of data that can be returned, not counting the
        separator.

        If an EOF occurs and the complete separator is still not found,
        an IncompleteReadError exception will be raised, and the internal
        buffer will be reset.  The IncompleteReadError.partial attribute
        may contain the separator partially.

        If the data cannot be read because of over limit, a
        LimitOverrunError exception will be raised, and the data
        will be left in the internal buffer, so it can be read again.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError('Separator should be at least one-byte string')

        if self._exception is not None:
            raise self._exception

        # Consume whole buffer except last bytes, which length is
        # one less than seplen. Let's check corner cases with
        # separator='SEPARATOR':
        # * we have received almost complete separator (without last
        #   byte). i.e buffer='some textSEPARATO'. In this case we
        #   can safely consume len(separator) - 1 bytes.
        # * last byte of buffer is first byte of separator, i.e.
        #   buffer='abcdefghijklmnopqrS'. We may safely consume
        #   everything except that last byte, but this require to
        #   analyze bytes of buffer that match partial separator.
        #   This is slow and/or require FSM. For this case our
        #   implementation is not optimal, since require rescanning
        #   of data that is known to not belong to separator. In
        #   real world, separator will not be so long to notice
        #   performance problems. Even when reading MIME-encoded
        #   messages :)

        # `offset` is the number of bytes from the beginning of the buffer
        # where there is no occurrence of `separator`.
        offset = 0

        # Loop until we find `separator` in the buffer, exceed the buffer size,
        # or an EOF has happened.
        while True:
            buflen = len(self._buffer)

            # Check if we now have enough data in the buffer for `separator` to
            # fit.
            if buflen - offset >= seplen:
                isep = self._buffer.find(separator, offset)

                if isep != -1:
                    # `separator` is in the buffer. `isep` will be used later
                    # to retrieve the data.
                    break

                # see upper comment for explanation.
                offset = buflen + 1 - seplen
                if offset > self._limit:
                    raise exceptions.LimitOverrunError(
                        'Separator is not found, and chunk exceed the limit',
                        offset)

            # Complete message (with full separator) may be present in buffer
            # even when EOF flag is set. This may happen when the last chunk
            # adds data which makes separator be found. That's why we check for
            # EOF *after* inspecting the buffer.
            if self._eof:
                chunk = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(chunk, None)

            # _wait_for_data() will resume reading if stream was paused.
            await self._wait_for_data('readuntil')

        if isep > self._limit:
            raise exceptions.LimitOverrunError(
                'Separator is found, but chunk is longer than limit', isep)

        chunk = self._buffer[:isep + seplen]
        del self._buffer[:isep + seplen]
        self._maybe_resume_transport()
        return bytes(chunk)

    async def read(self, n=-1):
        """Read up to `n` bytes from the stream.

        If `n` is not provided or set to -1,
        read until EOF, then return all read bytes.
        If EOF was received and the internal buffer is empty,
        return an empty bytes object.

        If `n` is 0, return an empty bytes object immediately.

        If `n` is positive, return at most `n` available bytes
        as soon as at least 1 byte is available in the internal buffer.
        If EOF is received before any byte is read, return an empty
        bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = await self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)

        if not self._buffer and not self._eof:
            await self._wait_for_data('read')

        # This will work right even if buffer is less than n bytes
        data = bytes(self._buffer[:n])
        del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    async def readexactly(self, n):
        """Read exactly `n` bytes.

        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
        read. The IncompleteReadError.partial attribute of the exception will
        contain the partial read bytes.

        if n is zero, return empty bytes object.

        Returned value is not limited with limit, configured at stream
        creation.

        If stream was paused, this function will automatically resume it if
        needed.
        """
        if n < 0:
            raise ValueError('readexactly size can not be less than zero')

        if self._exception is not None:
            raise self._exception

        if n == 0:
            return b''

        while len(self._buffer) < n:
            if self._eof:
                incomplete = bytes(self._buffer)
                self._buffer.clear()
                raise exceptions.IncompleteReadError(incomplete, n)

            await self._wait_for_data('readexactly')

        if len(self._buffer) == n:
            # Exact fit: hand over the whole buffer.
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            data = bytes(self._buffer[:n])
            del self._buffer[:n]
        self._maybe_resume_transport()
        return data

    def __aiter__(self):
        """Return self: the reader is its own asynchronous iterator."""
        return self

    async def __anext__(self):
        """Yield successive lines; stop on EOF (readline() returns b'')."""
        val = await self.readline()
        if val == b'':
            raise StopAsyncIteration
        return val
748