1__all__ = ( 2 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', 3 'open_connection', 'start_server') 4 5import collections 6import socket 7import sys 8import warnings 9import weakref 10 11if hasattr(socket, 'AF_UNIX'): 12 __all__ += ('open_unix_connection', 'start_unix_server') 13 14from . import coroutines 15from . import events 16from . import exceptions 17from . import format_helpers 18from . import protocols 19from .log import logger 20from .tasks import sleep 21 22 23_DEFAULT_LIMIT = 2 ** 16 # 64 KiB 24 25 26async def open_connection(host=None, port=None, *, 27 limit=_DEFAULT_LIMIT, **kwds): 28 """A wrapper for create_connection() returning a (reader, writer) pair. 29 30 The reader returned is a StreamReader instance; the writer is a 31 StreamWriter instance. 32 33 The arguments are all the usual arguments to create_connection() 34 except protocol_factory; most common are positional host and port, 35 with various optional keyword arguments following. 36 37 Additional optional keyword arguments are loop (to set the event loop 38 instance to use) and limit (to set the buffer limit passed to the 39 StreamReader). 40 41 (If you want to customize the StreamReader and/or 42 StreamReaderProtocol classes, just copy the code -- there's 43 really nothing special here except some convenience.) 44 """ 45 loop = events.get_running_loop() 46 reader = StreamReader(limit=limit, loop=loop) 47 protocol = StreamReaderProtocol(reader, loop=loop) 48 transport, _ = await loop.create_connection( 49 lambda: protocol, host, port, **kwds) 50 writer = StreamWriter(transport, protocol, reader, loop) 51 return reader, writer 52 53 54async def start_server(client_connected_cb, host=None, port=None, *, 55 limit=_DEFAULT_LIMIT, **kwds): 56 """Start a socket server, call back for each client connected. 57 58 The first parameter, `client_connected_cb`, takes two parameters: 59 client_reader, client_writer. client_reader is a StreamReader 60 object, while client_writer is a StreamWriter object. This 61 parameter can either be a plain callback function or a coroutine; 62 if it is a coroutine, it will be automatically converted into a 63 Task. 64 65 The rest of the arguments are all the usual arguments to 66 loop.create_server() except protocol_factory; most common are 67 positional host and port, with various optional keyword arguments 68 following. The return value is the same as loop.create_server(). 69 70 Additional optional keyword arguments are loop (to set the event loop 71 instance to use) and limit (to set the buffer limit passed to the 72 StreamReader). 73 74 The return value is the same as loop.create_server(), i.e. a 75 Server object which can be used to stop the service. 76 """ 77 loop = events.get_running_loop() 78 79 def factory(): 80 reader = StreamReader(limit=limit, loop=loop) 81 protocol = StreamReaderProtocol(reader, client_connected_cb, 82 loop=loop) 83 return protocol 84 85 return await loop.create_server(factory, host, port, **kwds) 86 87 88if hasattr(socket, 'AF_UNIX'): 89 # UNIX Domain Sockets are supported on this platform 90 91 async def open_unix_connection(path=None, *, 92 limit=_DEFAULT_LIMIT, **kwds): 93 """Similar to `open_connection` but works with UNIX Domain Sockets.""" 94 loop = events.get_running_loop() 95 96 reader = StreamReader(limit=limit, loop=loop) 97 protocol = StreamReaderProtocol(reader, loop=loop) 98 transport, _ = await loop.create_unix_connection( 99 lambda: protocol, path, **kwds) 100 writer = StreamWriter(transport, protocol, reader, loop) 101 return reader, writer 102 103 async def start_unix_server(client_connected_cb, path=None, *, 104 limit=_DEFAULT_LIMIT, **kwds): 105 """Similar to `start_server` but works with UNIX Domain Sockets.""" 106 loop = events.get_running_loop() 107 108 def factory(): 109 reader = StreamReader(limit=limit, loop=loop) 110 protocol = StreamReaderProtocol(reader, client_connected_cb, 111 loop=loop) 112 return protocol 113 114 return await loop.create_unix_server(factory, path, **kwds) 115 116 117class FlowControlMixin(protocols.Protocol): 118 """Reusable flow control logic for StreamWriter.drain(). 119 120 This implements the protocol methods pause_writing(), 121 resume_writing() and connection_lost(). If the subclass overrides 122 these it must call the super methods. 123 124 StreamWriter.drain() must wait for _drain_helper() coroutine. 125 """ 126 127 def __init__(self, loop=None): 128 if loop is None: 129 self._loop = events._get_event_loop(stacklevel=4) 130 else: 131 self._loop = loop 132 self._paused = False 133 self._drain_waiters = collections.deque() 134 self._connection_lost = False 135 136 def pause_writing(self): 137 assert not self._paused 138 self._paused = True 139 if self._loop.get_debug(): 140 logger.debug("%r pauses writing", self) 141 142 def resume_writing(self): 143 assert self._paused 144 self._paused = False 145 if self._loop.get_debug(): 146 logger.debug("%r resumes writing", self) 147 148 for waiter in self._drain_waiters: 149 if not waiter.done(): 150 waiter.set_result(None) 151 152 def connection_lost(self, exc): 153 self._connection_lost = True 154 # Wake up the writer(s) if currently paused. 155 if not self._paused: 156 return 157 158 for waiter in self._drain_waiters: 159 if not waiter.done(): 160 if exc is None: 161 waiter.set_result(None) 162 else: 163 waiter.set_exception(exc) 164 165 async def _drain_helper(self): 166 if self._connection_lost: 167 raise ConnectionResetError('Connection lost') 168 if not self._paused: 169 return 170 waiter = self._loop.create_future() 171 self._drain_waiters.append(waiter) 172 try: 173 await waiter 174 finally: 175 self._drain_waiters.remove(waiter) 176 177 def _get_close_waiter(self, stream): 178 raise NotImplementedError 179 180 181class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): 182 """Helper class to adapt between Protocol and StreamReader. 183 184 (This is a helper class instead of making StreamReader itself a 185 Protocol subclass, because the StreamReader has other potential 186 uses, and to prevent the user of the StreamReader to accidentally 187 call inappropriate methods of the protocol.) 188 """ 189 190 _source_traceback = None 191 192 def __init__(self, stream_reader, client_connected_cb=None, loop=None): 193 super().__init__(loop=loop) 194 if stream_reader is not None: 195 self._stream_reader_wr = weakref.ref(stream_reader) 196 self._source_traceback = stream_reader._source_traceback 197 else: 198 self._stream_reader_wr = None 199 if client_connected_cb is not None: 200 # This is a stream created by the `create_server()` function. 201 # Keep a strong reference to the reader until a connection 202 # is established. 203 self._strong_reader = stream_reader 204 self._reject_connection = False 205 self._stream_writer = None 206 self._task = None 207 self._transport = None 208 self._client_connected_cb = client_connected_cb 209 self._over_ssl = False 210 self._closed = self._loop.create_future() 211 212 @property 213 def _stream_reader(self): 214 if self._stream_reader_wr is None: 215 return None 216 return self._stream_reader_wr() 217 218 def _replace_writer(self, writer): 219 loop = self._loop 220 transport = writer.transport 221 self._stream_writer = writer 222 self._transport = transport 223 self._over_ssl = transport.get_extra_info('sslcontext') is not None 224 225 def connection_made(self, transport): 226 if self._reject_connection: 227 context = { 228 'message': ('An open stream was garbage collected prior to ' 229 'establishing network connection; ' 230 'call "stream.close()" explicitly.') 231 } 232 if self._source_traceback: 233 context['source_traceback'] = self._source_traceback 234 self._loop.call_exception_handler(context) 235 transport.abort() 236 return 237 self._transport = transport 238 reader = self._stream_reader 239 if reader is not None: 240 reader.set_transport(transport) 241 self._over_ssl = transport.get_extra_info('sslcontext') is not None 242 if self._client_connected_cb is not None: 243 self._stream_writer = StreamWriter(transport, self, 244 reader, 245 self._loop) 246 res = self._client_connected_cb(reader, 247 self._stream_writer) 248 if coroutines.iscoroutine(res): 249 self._task = self._loop.create_task(res) 250 self._strong_reader = None 251 252 def connection_lost(self, exc): 253 reader = self._stream_reader 254 if reader is not None: 255 if exc is None: 256 reader.feed_eof() 257 else: 258 reader.set_exception(exc) 259 if not self._closed.done(): 260 if exc is None: 261 self._closed.set_result(None) 262 else: 263 self._closed.set_exception(exc) 264 super().connection_lost(exc) 265 self._stream_reader_wr = None 266 self._stream_writer = None 267 self._task = None 268 self._transport = None 269 270 def data_received(self, data): 271 reader = self._stream_reader 272 if reader is not None: 273 reader.feed_data(data) 274 275 def eof_received(self): 276 reader = self._stream_reader 277 if reader is not None: 278 reader.feed_eof() 279 if self._over_ssl: 280 # Prevent a warning in SSLProtocol.eof_received: 281 # "returning true from eof_received() 282 # has no effect when using ssl" 283 return False 284 return True 285 286 def _get_close_waiter(self, stream): 287 return self._closed 288 289 def __del__(self): 290 # Prevent reports about unhandled exceptions. 291 # Better than self._closed._log_traceback = False hack 292 try: 293 closed = self._closed 294 except AttributeError: 295 pass # failed constructor 296 else: 297 if closed.done() and not closed.cancelled(): 298 closed.exception() 299 300 301class StreamWriter: 302 """Wraps a Transport. 303 304 This exposes write(), writelines(), [can_]write_eof(), 305 get_extra_info() and close(). It adds drain() which returns an 306 optional Future on which you can wait for flow control. It also 307 adds a transport property which references the Transport 308 directly. 309 """ 310 311 def __init__(self, transport, protocol, reader, loop): 312 self._transport = transport 313 self._protocol = protocol 314 # drain() expects that the reader has an exception() method 315 assert reader is None or isinstance(reader, StreamReader) 316 self._reader = reader 317 self._loop = loop 318 self._complete_fut = self._loop.create_future() 319 self._complete_fut.set_result(None) 320 321 def __repr__(self): 322 info = [self.__class__.__name__, f'transport={self._transport!r}'] 323 if self._reader is not None: 324 info.append(f'reader={self._reader!r}') 325 return '<{}>'.format(' '.join(info)) 326 327 @property 328 def transport(self): 329 return self._transport 330 331 def write(self, data): 332 self._transport.write(data) 333 334 def writelines(self, data): 335 self._transport.writelines(data) 336 337 def write_eof(self): 338 return self._transport.write_eof() 339 340 def can_write_eof(self): 341 return self._transport.can_write_eof() 342 343 def close(self): 344 return self._transport.close() 345 346 def is_closing(self): 347 return self._transport.is_closing() 348 349 async def wait_closed(self): 350 await self._protocol._get_close_waiter(self) 351 352 def get_extra_info(self, name, default=None): 353 return self._transport.get_extra_info(name, default) 354 355 async def drain(self): 356 """Flush the write buffer. 357 358 The intended use is to write 359 360 w.write(data) 361 await w.drain() 362 """ 363 if self._reader is not None: 364 exc = self._reader.exception() 365 if exc is not None: 366 raise exc 367 if self._transport.is_closing(): 368 # Wait for protocol.connection_lost() call 369 # Raise connection closing error if any, 370 # ConnectionResetError otherwise 371 # Yield to the event loop so connection_lost() may be 372 # called. Without this, _drain_helper() would return 373 # immediately, and code that calls 374 # write(...); await drain() 375 # in a loop would never call connection_lost(), so it 376 # would not see an error when the socket is closed. 377 await sleep(0) 378 await self._protocol._drain_helper() 379 380 async def start_tls(self, sslcontext, *, 381 server_hostname=None, 382 ssl_handshake_timeout=None): 383 """Upgrade an existing stream-based connection to TLS.""" 384 server_side = self._protocol._client_connected_cb is not None 385 protocol = self._protocol 386 await self.drain() 387 new_transport = await self._loop.start_tls( # type: ignore 388 self._transport, protocol, sslcontext, 389 server_side=server_side, server_hostname=server_hostname, 390 ssl_handshake_timeout=ssl_handshake_timeout) 391 self._transport = new_transport 392 protocol._replace_writer(self) 393 394 395class StreamReader: 396 397 _source_traceback = None 398 399 def __init__(self, limit=_DEFAULT_LIMIT, loop=None): 400 # The line length limit is a security feature; 401 # it also doubles as half the buffer limit. 402 403 if limit <= 0: 404 raise ValueError('Limit cannot be <= 0') 405 406 self._limit = limit 407 if loop is None: 408 self._loop = events._get_event_loop() 409 else: 410 self._loop = loop 411 self._buffer = bytearray() 412 self._eof = False # Whether we're done. 413 self._waiter = None # A future used by _wait_for_data() 414 self._exception = None 415 self._transport = None 416 self._paused = False 417 if self._loop.get_debug(): 418 self._source_traceback = format_helpers.extract_stack( 419 sys._getframe(1)) 420 421 def __repr__(self): 422 info = ['StreamReader'] 423 if self._buffer: 424 info.append(f'{len(self._buffer)} bytes') 425 if self._eof: 426 info.append('eof') 427 if self._limit != _DEFAULT_LIMIT: 428 info.append(f'limit={self._limit}') 429 if self._waiter: 430 info.append(f'waiter={self._waiter!r}') 431 if self._exception: 432 info.append(f'exception={self._exception!r}') 433 if self._transport: 434 info.append(f'transport={self._transport!r}') 435 if self._paused: 436 info.append('paused') 437 return '<{}>'.format(' '.join(info)) 438 439 def exception(self): 440 return self._exception 441 442 def set_exception(self, exc): 443 self._exception = exc 444 445 waiter = self._waiter 446 if waiter is not None: 447 self._waiter = None 448 if not waiter.cancelled(): 449 waiter.set_exception(exc) 450 451 def _wakeup_waiter(self): 452 """Wakeup read*() functions waiting for data or EOF.""" 453 waiter = self._waiter 454 if waiter is not None: 455 self._waiter = None 456 if not waiter.cancelled(): 457 waiter.set_result(None) 458 459 def set_transport(self, transport): 460 assert self._transport is None, 'Transport already set' 461 self._transport = transport 462 463 def _maybe_resume_transport(self): 464 if self._paused and len(self._buffer) <= self._limit: 465 self._paused = False 466 self._transport.resume_reading() 467 468 def feed_eof(self): 469 self._eof = True 470 self._wakeup_waiter() 471 472 def at_eof(self): 473 """Return True if the buffer is empty and 'feed_eof' was called.""" 474 return self._eof and not self._buffer 475 476 def feed_data(self, data): 477 assert not self._eof, 'feed_data after feed_eof' 478 479 if not data: 480 return 481 482 self._buffer.extend(data) 483 self._wakeup_waiter() 484 485 if (self._transport is not None and 486 not self._paused and 487 len(self._buffer) > 2 * self._limit): 488 try: 489 self._transport.pause_reading() 490 except NotImplementedError: 491 # The transport can't be paused. 492 # We'll just have to buffer all data. 493 # Forget the transport so we don't keep trying. 494 self._transport = None 495 else: 496 self._paused = True 497 498 async def _wait_for_data(self, func_name): 499 """Wait until feed_data() or feed_eof() is called. 500 501 If stream was paused, automatically resume it. 502 """ 503 # StreamReader uses a future to link the protocol feed_data() method 504 # to a read coroutine. Running two read coroutines at the same time 505 # would have an unexpected behaviour. It would not possible to know 506 # which coroutine would get the next data. 507 if self._waiter is not None: 508 raise RuntimeError( 509 f'{func_name}() called while another coroutine is ' 510 f'already waiting for incoming data') 511 512 assert not self._eof, '_wait_for_data after EOF' 513 514 # Waiting for data while paused will make deadlock, so prevent it. 515 # This is essential for readexactly(n) for case when n > self._limit. 516 if self._paused: 517 self._paused = False 518 self._transport.resume_reading() 519 520 self._waiter = self._loop.create_future() 521 try: 522 await self._waiter 523 finally: 524 self._waiter = None 525 526 async def readline(self): 527 """Read chunk of data from the stream until newline (b'\n') is found. 528 529 On success, return chunk that ends with newline. If only partial 530 line can be read due to EOF, return incomplete line without 531 terminating newline. When EOF was reached while no bytes read, empty 532 bytes object is returned. 533 534 If limit is reached, ValueError will be raised. In that case, if 535 newline was found, complete line including newline will be removed 536 from internal buffer. Else, internal buffer will be cleared. Limit is 537 compared against part of the line without newline. 538 539 If stream was paused, this function will automatically resume it if 540 needed. 541 """ 542 sep = b'\n' 543 seplen = len(sep) 544 try: 545 line = await self.readuntil(sep) 546 except exceptions.IncompleteReadError as e: 547 return e.partial 548 except exceptions.LimitOverrunError as e: 549 if self._buffer.startswith(sep, e.consumed): 550 del self._buffer[:e.consumed + seplen] 551 else: 552 self._buffer.clear() 553 self._maybe_resume_transport() 554 raise ValueError(e.args[0]) 555 return line 556 557 async def readuntil(self, separator=b'\n'): 558 """Read data from the stream until ``separator`` is found. 559 560 On success, the data and separator will be removed from the 561 internal buffer (consumed). Returned data will include the 562 separator at the end. 563 564 Configured stream limit is used to check result. Limit sets the 565 maximal length of data that can be returned, not counting the 566 separator. 567 568 If an EOF occurs and the complete separator is still not found, 569 an IncompleteReadError exception will be raised, and the internal 570 buffer will be reset. The IncompleteReadError.partial attribute 571 may contain the separator partially. 572 573 If the data cannot be read because of over limit, a 574 LimitOverrunError exception will be raised, and the data 575 will be left in the internal buffer, so it can be read again. 576 """ 577 seplen = len(separator) 578 if seplen == 0: 579 raise ValueError('Separator should be at least one-byte string') 580 581 if self._exception is not None: 582 raise self._exception 583 584 # Consume whole buffer except last bytes, which length is 585 # one less than seplen. Let's check corner cases with 586 # separator='SEPARATOR': 587 # * we have received almost complete separator (without last 588 # byte). i.e buffer='some textSEPARATO'. In this case we 589 # can safely consume len(separator) - 1 bytes. 590 # * last byte of buffer is first byte of separator, i.e. 591 # buffer='abcdefghijklmnopqrS'. We may safely consume 592 # everything except that last byte, but this require to 593 # analyze bytes of buffer that match partial separator. 594 # This is slow and/or require FSM. For this case our 595 # implementation is not optimal, since require rescanning 596 # of data that is known to not belong to separator. In 597 # real world, separator will not be so long to notice 598 # performance problems. Even when reading MIME-encoded 599 # messages :) 600 601 # `offset` is the number of bytes from the beginning of the buffer 602 # where there is no occurrence of `separator`. 603 offset = 0 604 605 # Loop until we find `separator` in the buffer, exceed the buffer size, 606 # or an EOF has happened. 607 while True: 608 buflen = len(self._buffer) 609 610 # Check if we now have enough data in the buffer for `separator` to 611 # fit. 612 if buflen - offset >= seplen: 613 isep = self._buffer.find(separator, offset) 614 615 if isep != -1: 616 # `separator` is in the buffer. `isep` will be used later 617 # to retrieve the data. 618 break 619 620 # see upper comment for explanation. 621 offset = buflen + 1 - seplen 622 if offset > self._limit: 623 raise exceptions.LimitOverrunError( 624 'Separator is not found, and chunk exceed the limit', 625 offset) 626 627 # Complete message (with full separator) may be present in buffer 628 # even when EOF flag is set. This may happen when the last chunk 629 # adds data which makes separator be found. That's why we check for 630 # EOF *ater* inspecting the buffer. 631 if self._eof: 632 chunk = bytes(self._buffer) 633 self._buffer.clear() 634 raise exceptions.IncompleteReadError(chunk, None) 635 636 # _wait_for_data() will resume reading if stream was paused. 637 await self._wait_for_data('readuntil') 638 639 if isep > self._limit: 640 raise exceptions.LimitOverrunError( 641 'Separator is found, but chunk is longer than limit', isep) 642 643 chunk = self._buffer[:isep + seplen] 644 del self._buffer[:isep + seplen] 645 self._maybe_resume_transport() 646 return bytes(chunk) 647 648 async def read(self, n=-1): 649 """Read up to `n` bytes from the stream. 650 651 If `n` is not provided or set to -1, 652 read until EOF, then return all read bytes. 653 If EOF was received and the internal buffer is empty, 654 return an empty bytes object. 655 656 If `n` is 0, return an empty bytes object immediately. 657 658 If `n` is positive, return at most `n` available bytes 659 as soon as at least 1 byte is available in the internal buffer. 660 If EOF is received before any byte is read, return an empty 661 bytes object. 662 663 Returned value is not limited with limit, configured at stream 664 creation. 665 666 If stream was paused, this function will automatically resume it if 667 needed. 668 """ 669 670 if self._exception is not None: 671 raise self._exception 672 673 if n == 0: 674 return b'' 675 676 if n < 0: 677 # This used to just loop creating a new waiter hoping to 678 # collect everything in self._buffer, but that would 679 # deadlock if the subprocess sends more than self.limit 680 # bytes. So just call self.read(self._limit) until EOF. 681 blocks = [] 682 while True: 683 block = await self.read(self._limit) 684 if not block: 685 break 686 blocks.append(block) 687 return b''.join(blocks) 688 689 if not self._buffer and not self._eof: 690 await self._wait_for_data('read') 691 692 # This will work right even if buffer is less than n bytes 693 data = bytes(self._buffer[:n]) 694 del self._buffer[:n] 695 696 self._maybe_resume_transport() 697 return data 698 699 async def readexactly(self, n): 700 """Read exactly `n` bytes. 701 702 Raise an IncompleteReadError if EOF is reached before `n` bytes can be 703 read. The IncompleteReadError.partial attribute of the exception will 704 contain the partial read bytes. 705 706 if n is zero, return empty bytes object. 707 708 Returned value is not limited with limit, configured at stream 709 creation. 710 711 If stream was paused, this function will automatically resume it if 712 needed. 713 """ 714 if n < 0: 715 raise ValueError('readexactly size can not be less than zero') 716 717 if self._exception is not None: 718 raise self._exception 719 720 if n == 0: 721 return b'' 722 723 while len(self._buffer) < n: 724 if self._eof: 725 incomplete = bytes(self._buffer) 726 self._buffer.clear() 727 raise exceptions.IncompleteReadError(incomplete, n) 728 729 await self._wait_for_data('readexactly') 730 731 if len(self._buffer) == n: 732 data = bytes(self._buffer) 733 self._buffer.clear() 734 else: 735 data = bytes(self._buffer[:n]) 736 del self._buffer[:n] 737 self._maybe_resume_transport() 738 return data 739 740 def __aiter__(self): 741 return self 742 743 async def __anext__(self): 744 val = await self.readline() 745 if val == b'': 746 raise StopAsyncIteration 747 return val 748