1import asyncio 2import asyncio.sslproto 3import contextlib 4import gc 5import logging 6import select 7import socket 8import sys 9import tempfile 10import threading 11import time 12import weakref 13import unittest 14 15try: 16 import ssl 17except ImportError: 18 ssl = None 19 20from test import support 21from test.test_asyncio import utils as test_utils 22 23 24MACOS = (sys.platform == 'darwin') 25BUF_MULTIPLIER = 1024 if not MACOS else 64 26 27 28def tearDownModule(): 29 asyncio.set_event_loop_policy(None) 30 31 32class MyBaseProto(asyncio.Protocol): 33 connected = None 34 done = None 35 36 def __init__(self, loop=None): 37 self.transport = None 38 self.state = 'INITIAL' 39 self.nbytes = 0 40 if loop is not None: 41 self.connected = asyncio.Future(loop=loop) 42 self.done = asyncio.Future(loop=loop) 43 44 def connection_made(self, transport): 45 self.transport = transport 46 assert self.state == 'INITIAL', self.state 47 self.state = 'CONNECTED' 48 if self.connected: 49 self.connected.set_result(None) 50 51 def data_received(self, data): 52 assert self.state == 'CONNECTED', self.state 53 self.nbytes += len(data) 54 55 def eof_received(self): 56 assert self.state == 'CONNECTED', self.state 57 self.state = 'EOF' 58 59 def connection_lost(self, exc): 60 assert self.state in ('CONNECTED', 'EOF'), self.state 61 self.state = 'CLOSED' 62 if self.done: 63 self.done.set_result(None) 64 65 66class MessageOutFilter(logging.Filter): 67 def __init__(self, msg): 68 self.msg = msg 69 70 def filter(self, record): 71 if self.msg in record.msg: 72 return False 73 return True 74 75 76@unittest.skipIf(ssl is None, 'No ssl module') 77class TestSSL(test_utils.TestCase): 78 79 PAYLOAD_SIZE = 1024 * 100 80 TIMEOUT = support.LONG_TIMEOUT 81 82 def setUp(self): 83 super().setUp() 84 self.loop = asyncio.new_event_loop() 85 self.set_event_loop(self.loop) 86 self.addCleanup(self.loop.close) 87 88 def tearDown(self): 89 # just in case if we have transport close callbacks 90 if not self.loop.is_closed(): 91 test_utils.run_briefly(self.loop) 92 93 self.doCleanups() 94 support.gc_collect() 95 super().tearDown() 96 97 def tcp_server(self, server_prog, *, 98 family=socket.AF_INET, 99 addr=None, 100 timeout=support.SHORT_TIMEOUT, 101 backlog=1, 102 max_clients=10): 103 104 if addr is None: 105 if family == getattr(socket, "AF_UNIX", None): 106 with tempfile.NamedTemporaryFile() as tmp: 107 addr = tmp.name 108 else: 109 addr = ('127.0.0.1', 0) 110 111 sock = socket.socket(family, socket.SOCK_STREAM) 112 113 if timeout is None: 114 raise RuntimeError('timeout is required') 115 if timeout <= 0: 116 raise RuntimeError('only blocking sockets are supported') 117 sock.settimeout(timeout) 118 119 try: 120 sock.bind(addr) 121 sock.listen(backlog) 122 except OSError as ex: 123 sock.close() 124 raise ex 125 126 return TestThreadedServer( 127 self, sock, server_prog, timeout, max_clients) 128 129 def tcp_client(self, client_prog, 130 family=socket.AF_INET, 131 timeout=support.SHORT_TIMEOUT): 132 133 sock = socket.socket(family, socket.SOCK_STREAM) 134 135 if timeout is None: 136 raise RuntimeError('timeout is required') 137 if timeout <= 0: 138 raise RuntimeError('only blocking sockets are supported') 139 sock.settimeout(timeout) 140 141 return TestThreadedClient( 142 self, sock, client_prog, timeout) 143 144 def unix_server(self, *args, **kwargs): 145 return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs) 146 147 def unix_client(self, *args, **kwargs): 148 return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs) 149 150 def _create_server_ssl_context(self, certfile, keyfile=None): 151 sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 152 sslcontext.options |= ssl.OP_NO_SSLv2 153 sslcontext.load_cert_chain(certfile, keyfile) 154 return sslcontext 155 156 def _create_client_ssl_context(self, *, disable_verify=True): 157 sslcontext = ssl.create_default_context() 158 sslcontext.check_hostname = False 159 if disable_verify: 160 sslcontext.verify_mode = ssl.CERT_NONE 161 return sslcontext 162 163 @contextlib.contextmanager 164 def _silence_eof_received_warning(self): 165 # TODO This warning has to be fixed in asyncio. 166 logger = logging.getLogger('asyncio') 167 filter = MessageOutFilter('has no effect when using ssl') 168 logger.addFilter(filter) 169 try: 170 yield 171 finally: 172 logger.removeFilter(filter) 173 174 def _abort_socket_test(self, ex): 175 try: 176 self.loop.stop() 177 finally: 178 self.fail(ex) 179 180 def new_loop(self): 181 return asyncio.new_event_loop() 182 183 def new_policy(self): 184 return asyncio.DefaultEventLoopPolicy() 185 186 async def wait_closed(self, obj): 187 if not isinstance(obj, asyncio.StreamWriter): 188 return 189 try: 190 await obj.wait_closed() 191 except (BrokenPipeError, ConnectionError): 192 pass 193 194 def test_create_server_ssl_1(self): 195 CNT = 0 # number of clients that were successful 196 TOTAL_CNT = 25 # total number of clients that test will create 197 TIMEOUT = support.LONG_TIMEOUT # timeout for this test 198 199 A_DATA = b'A' * 1024 * BUF_MULTIPLIER 200 B_DATA = b'B' * 1024 * BUF_MULTIPLIER 201 202 sslctx = self._create_server_ssl_context( 203 test_utils.ONLYCERT, test_utils.ONLYKEY 204 ) 205 client_sslctx = self._create_client_ssl_context() 206 207 clients = [] 208 209 async def handle_client(reader, writer): 210 nonlocal CNT 211 212 data = await reader.readexactly(len(A_DATA)) 213 self.assertEqual(data, A_DATA) 214 writer.write(b'OK') 215 216 data = await reader.readexactly(len(B_DATA)) 217 self.assertEqual(data, B_DATA) 218 writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) 219 220 await writer.drain() 221 writer.close() 222 223 CNT += 1 224 225 async def test_client(addr): 226 fut = asyncio.Future() 227 228 def prog(sock): 229 try: 230 sock.starttls(client_sslctx) 231 sock.connect(addr) 232 sock.send(A_DATA) 233 234 data = sock.recv_all(2) 235 self.assertEqual(data, b'OK') 236 237 sock.send(B_DATA) 238 data = sock.recv_all(4) 239 self.assertEqual(data, b'SPAM') 240 241 sock.close() 242 243 except Exception as ex: 244 self.loop.call_soon_threadsafe(fut.set_exception, ex) 245 else: 246 self.loop.call_soon_threadsafe(fut.set_result, None) 247 248 client = self.tcp_client(prog) 249 client.start() 250 clients.append(client) 251 252 await fut 253 254 async def start_server(): 255 extras = {} 256 extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) 257 258 srv = await asyncio.start_server( 259 handle_client, 260 '127.0.0.1', 0, 261 family=socket.AF_INET, 262 ssl=sslctx, 263 **extras) 264 265 try: 266 srv_socks = srv.sockets 267 self.assertTrue(srv_socks) 268 269 addr = srv_socks[0].getsockname() 270 271 tasks = [] 272 for _ in range(TOTAL_CNT): 273 tasks.append(test_client(addr)) 274 275 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) 276 277 finally: 278 self.loop.call_soon(srv.close) 279 await srv.wait_closed() 280 281 with self._silence_eof_received_warning(): 282 self.loop.run_until_complete(start_server()) 283 284 self.assertEqual(CNT, TOTAL_CNT) 285 286 for client in clients: 287 client.stop() 288 289 def test_create_connection_ssl_1(self): 290 self.loop.set_exception_handler(None) 291 292 CNT = 0 293 TOTAL_CNT = 25 294 295 A_DATA = b'A' * 1024 * BUF_MULTIPLIER 296 B_DATA = b'B' * 1024 * BUF_MULTIPLIER 297 298 sslctx = self._create_server_ssl_context( 299 test_utils.ONLYCERT, 300 test_utils.ONLYKEY 301 ) 302 client_sslctx = self._create_client_ssl_context() 303 304 def server(sock): 305 sock.starttls( 306 sslctx, 307 server_side=True) 308 309 data = sock.recv_all(len(A_DATA)) 310 self.assertEqual(data, A_DATA) 311 sock.send(b'OK') 312 313 data = sock.recv_all(len(B_DATA)) 314 self.assertEqual(data, B_DATA) 315 sock.send(b'SPAM') 316 317 sock.close() 318 319 async def client(addr): 320 extras = {} 321 extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) 322 323 reader, writer = await asyncio.open_connection( 324 *addr, 325 ssl=client_sslctx, 326 server_hostname='', 327 **extras) 328 329 writer.write(A_DATA) 330 self.assertEqual(await reader.readexactly(2), b'OK') 331 332 writer.write(B_DATA) 333 self.assertEqual(await reader.readexactly(4), b'SPAM') 334 335 nonlocal CNT 336 CNT += 1 337 338 writer.close() 339 await self.wait_closed(writer) 340 341 async def client_sock(addr): 342 sock = socket.socket() 343 sock.connect(addr) 344 reader, writer = await asyncio.open_connection( 345 sock=sock, 346 ssl=client_sslctx, 347 server_hostname='') 348 349 writer.write(A_DATA) 350 self.assertEqual(await reader.readexactly(2), b'OK') 351 352 writer.write(B_DATA) 353 self.assertEqual(await reader.readexactly(4), b'SPAM') 354 355 nonlocal CNT 356 CNT += 1 357 358 writer.close() 359 await self.wait_closed(writer) 360 sock.close() 361 362 def run(coro): 363 nonlocal CNT 364 CNT = 0 365 366 async def _gather(*tasks): 367 # trampoline 368 return await asyncio.gather(*tasks) 369 370 with self.tcp_server(server, 371 max_clients=TOTAL_CNT, 372 backlog=TOTAL_CNT) as srv: 373 tasks = [] 374 for _ in range(TOTAL_CNT): 375 tasks.append(coro(srv.addr)) 376 377 self.loop.run_until_complete(_gather(*tasks)) 378 379 self.assertEqual(CNT, TOTAL_CNT) 380 381 with self._silence_eof_received_warning(): 382 run(client) 383 384 with self._silence_eof_received_warning(): 385 run(client_sock) 386 387 def test_create_connection_ssl_slow_handshake(self): 388 client_sslctx = self._create_client_ssl_context() 389 390 # silence error logger 391 self.loop.set_exception_handler(lambda *args: None) 392 393 def server(sock): 394 try: 395 sock.recv_all(1024 * 1024) 396 except ConnectionAbortedError: 397 pass 398 finally: 399 sock.close() 400 401 async def client(addr): 402 reader, writer = await asyncio.open_connection( 403 *addr, 404 ssl=client_sslctx, 405 server_hostname='', 406 ssl_handshake_timeout=1.0) 407 writer.close() 408 await self.wait_closed(writer) 409 410 with self.tcp_server(server, 411 max_clients=1, 412 backlog=1) as srv: 413 414 with self.assertRaisesRegex( 415 ConnectionAbortedError, 416 r'SSL handshake.*is taking longer'): 417 418 self.loop.run_until_complete(client(srv.addr)) 419 420 def test_create_connection_ssl_failed_certificate(self): 421 # silence error logger 422 self.loop.set_exception_handler(lambda *args: None) 423 424 sslctx = self._create_server_ssl_context( 425 test_utils.ONLYCERT, 426 test_utils.ONLYKEY 427 ) 428 client_sslctx = self._create_client_ssl_context(disable_verify=False) 429 430 def server(sock): 431 try: 432 sock.starttls( 433 sslctx, 434 server_side=True) 435 sock.connect() 436 except (ssl.SSLError, OSError): 437 pass 438 finally: 439 sock.close() 440 441 async def client(addr): 442 reader, writer = await asyncio.open_connection( 443 *addr, 444 ssl=client_sslctx, 445 server_hostname='', 446 ssl_handshake_timeout=support.SHORT_TIMEOUT) 447 writer.close() 448 await self.wait_closed(writer) 449 450 with self.tcp_server(server, 451 max_clients=1, 452 backlog=1) as srv: 453 454 with self.assertRaises(ssl.SSLCertVerificationError): 455 self.loop.run_until_complete(client(srv.addr)) 456 457 def test_ssl_handshake_timeout(self): 458 # bpo-29970: Check that a connection is aborted if handshake is not 459 # completed in timeout period, instead of remaining open indefinitely 460 client_sslctx = test_utils.simple_client_sslcontext() 461 462 # silence error logger 463 messages = [] 464 self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) 465 466 server_side_aborted = False 467 468 def server(sock): 469 nonlocal server_side_aborted 470 try: 471 sock.recv_all(1024 * 1024) 472 except ConnectionAbortedError: 473 server_side_aborted = True 474 finally: 475 sock.close() 476 477 async def client(addr): 478 await asyncio.wait_for( 479 self.loop.create_connection( 480 asyncio.Protocol, 481 *addr, 482 ssl=client_sslctx, 483 server_hostname='', 484 ssl_handshake_timeout=10.0), 485 0.5) 486 487 with self.tcp_server(server, 488 max_clients=1, 489 backlog=1) as srv: 490 491 with self.assertRaises(asyncio.TimeoutError): 492 self.loop.run_until_complete(client(srv.addr)) 493 494 self.assertTrue(server_side_aborted) 495 496 # Python issue #23197: cancelling a handshake must not raise an 497 # exception or log an error, even if the handshake failed 498 self.assertEqual(messages, []) 499 500 def test_ssl_handshake_connection_lost(self): 501 # #246: make sure that no connection_lost() is called before 502 # connection_made() is called first 503 504 client_sslctx = test_utils.simple_client_sslcontext() 505 506 # silence error logger 507 self.loop.set_exception_handler(lambda loop, ctx: None) 508 509 connection_made_called = False 510 connection_lost_called = False 511 512 def server(sock): 513 sock.recv(1024) 514 # break the connection during handshake 515 sock.close() 516 517 class ClientProto(asyncio.Protocol): 518 def connection_made(self, transport): 519 nonlocal connection_made_called 520 connection_made_called = True 521 522 def connection_lost(self, exc): 523 nonlocal connection_lost_called 524 connection_lost_called = True 525 526 async def client(addr): 527 await self.loop.create_connection( 528 ClientProto, 529 *addr, 530 ssl=client_sslctx, 531 server_hostname=''), 532 533 with self.tcp_server(server, 534 max_clients=1, 535 backlog=1) as srv: 536 537 with self.assertRaises(ConnectionResetError): 538 self.loop.run_until_complete(client(srv.addr)) 539 540 if connection_lost_called: 541 if connection_made_called: 542 self.fail("unexpected call to connection_lost()") 543 else: 544 self.fail("unexpected call to connection_lost() without" 545 "calling connection_made()") 546 elif connection_made_called: 547 self.fail("unexpected call to connection_made()") 548 549 def test_ssl_connect_accepted_socket(self): 550 proto = ssl.PROTOCOL_TLS_SERVER 551 server_context = ssl.SSLContext(proto) 552 server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY) 553 if hasattr(server_context, 'check_hostname'): 554 server_context.check_hostname = False 555 server_context.verify_mode = ssl.CERT_NONE 556 557 client_context = ssl.SSLContext(proto) 558 if hasattr(server_context, 'check_hostname'): 559 client_context.check_hostname = False 560 client_context.verify_mode = ssl.CERT_NONE 561 562 def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): 563 loop = self.loop 564 565 class MyProto(MyBaseProto): 566 567 def connection_lost(self, exc): 568 super().connection_lost(exc) 569 loop.call_soon(loop.stop) 570 571 def data_received(self, data): 572 super().data_received(data) 573 self.transport.write(expected_response) 574 575 lsock = socket.socket(socket.AF_INET) 576 lsock.bind(('127.0.0.1', 0)) 577 lsock.listen(1) 578 addr = lsock.getsockname() 579 580 message = b'test data' 581 response = None 582 expected_response = b'roger' 583 584 def client(): 585 nonlocal response 586 try: 587 csock = socket.socket(socket.AF_INET) 588 if client_ssl is not None: 589 csock = client_ssl.wrap_socket(csock) 590 csock.connect(addr) 591 csock.sendall(message) 592 response = csock.recv(99) 593 csock.close() 594 except Exception as exc: 595 print( 596 "Failure in client thread in test_connect_accepted_socket", 597 exc) 598 599 thread = threading.Thread(target=client, daemon=True) 600 thread.start() 601 602 conn, _ = lsock.accept() 603 proto = MyProto(loop=loop) 604 proto.loop = loop 605 606 extras = {} 607 if server_ssl: 608 extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) 609 610 f = loop.create_task( 611 loop.connect_accepted_socket( 612 (lambda: proto), conn, ssl=server_ssl, 613 **extras)) 614 loop.run_forever() 615 conn.close() 616 lsock.close() 617 618 thread.join(1) 619 self.assertFalse(thread.is_alive()) 620 self.assertEqual(proto.state, 'CLOSED') 621 self.assertEqual(proto.nbytes, len(message)) 622 self.assertEqual(response, expected_response) 623 tr, _ = f.result() 624 625 if server_ssl: 626 self.assertIn('SSL', tr.__class__.__name__) 627 628 tr.close() 629 # let it close 630 self.loop.run_until_complete(asyncio.sleep(0.1)) 631 632 def test_start_tls_client_corrupted_ssl(self): 633 self.loop.set_exception_handler(lambda loop, ctx: None) 634 635 sslctx = test_utils.simple_server_sslcontext() 636 client_sslctx = test_utils.simple_client_sslcontext() 637 638 def server(sock): 639 orig_sock = sock.dup() 640 try: 641 sock.starttls( 642 sslctx, 643 server_side=True) 644 sock.sendall(b'A\n') 645 sock.recv_all(1) 646 orig_sock.send(b'please corrupt the SSL connection') 647 except ssl.SSLError: 648 pass 649 finally: 650 sock.close() 651 orig_sock.close() 652 653 async def client(addr): 654 reader, writer = await asyncio.open_connection( 655 *addr, 656 ssl=client_sslctx, 657 server_hostname='') 658 659 self.assertEqual(await reader.readline(), b'A\n') 660 writer.write(b'B') 661 with self.assertRaises(ssl.SSLError): 662 await reader.readline() 663 writer.close() 664 try: 665 await self.wait_closed(writer) 666 except ssl.SSLError: 667 pass 668 return 'OK' 669 670 with self.tcp_server(server, 671 max_clients=1, 672 backlog=1) as srv: 673 674 res = self.loop.run_until_complete(client(srv.addr)) 675 676 self.assertEqual(res, 'OK') 677 678 def test_start_tls_client_reg_proto_1(self): 679 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 680 681 server_context = test_utils.simple_server_sslcontext() 682 client_context = test_utils.simple_client_sslcontext() 683 684 def serve(sock): 685 sock.settimeout(self.TIMEOUT) 686 687 data = sock.recv_all(len(HELLO_MSG)) 688 self.assertEqual(len(data), len(HELLO_MSG)) 689 690 sock.starttls(server_context, server_side=True) 691 692 sock.sendall(b'O') 693 data = sock.recv_all(len(HELLO_MSG)) 694 self.assertEqual(len(data), len(HELLO_MSG)) 695 696 sock.unwrap() 697 sock.close() 698 699 class ClientProto(asyncio.Protocol): 700 def __init__(self, on_data, on_eof): 701 self.on_data = on_data 702 self.on_eof = on_eof 703 self.con_made_cnt = 0 704 705 def connection_made(proto, tr): 706 proto.con_made_cnt += 1 707 # Ensure connection_made gets called only once. 708 self.assertEqual(proto.con_made_cnt, 1) 709 710 def data_received(self, data): 711 self.on_data.set_result(data) 712 713 def eof_received(self): 714 self.on_eof.set_result(True) 715 716 async def client(addr): 717 await asyncio.sleep(0.5) 718 719 on_data = self.loop.create_future() 720 on_eof = self.loop.create_future() 721 722 tr, proto = await self.loop.create_connection( 723 lambda: ClientProto(on_data, on_eof), *addr) 724 725 tr.write(HELLO_MSG) 726 new_tr = await self.loop.start_tls(tr, proto, client_context) 727 728 self.assertEqual(await on_data, b'O') 729 new_tr.write(HELLO_MSG) 730 await on_eof 731 732 new_tr.close() 733 734 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 735 self.loop.run_until_complete( 736 asyncio.wait_for(client(srv.addr), 737 timeout=support.SHORT_TIMEOUT)) 738 739 def test_create_connection_memory_leak(self): 740 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 741 742 server_context = self._create_server_ssl_context( 743 test_utils.ONLYCERT, test_utils.ONLYKEY) 744 client_context = self._create_client_ssl_context() 745 746 def serve(sock): 747 sock.settimeout(self.TIMEOUT) 748 749 sock.starttls(server_context, server_side=True) 750 751 sock.sendall(b'O') 752 data = sock.recv_all(len(HELLO_MSG)) 753 self.assertEqual(len(data), len(HELLO_MSG)) 754 755 sock.unwrap() 756 sock.close() 757 758 class ClientProto(asyncio.Protocol): 759 def __init__(self, on_data, on_eof): 760 self.on_data = on_data 761 self.on_eof = on_eof 762 self.con_made_cnt = 0 763 764 def connection_made(proto, tr): 765 # XXX: We assume user stores the transport in protocol 766 proto.tr = tr 767 proto.con_made_cnt += 1 768 # Ensure connection_made gets called only once. 769 self.assertEqual(proto.con_made_cnt, 1) 770 771 def data_received(self, data): 772 self.on_data.set_result(data) 773 774 def eof_received(self): 775 self.on_eof.set_result(True) 776 777 async def client(addr): 778 await asyncio.sleep(0.5) 779 780 on_data = self.loop.create_future() 781 on_eof = self.loop.create_future() 782 783 tr, proto = await self.loop.create_connection( 784 lambda: ClientProto(on_data, on_eof), *addr, 785 ssl=client_context) 786 787 self.assertEqual(await on_data, b'O') 788 tr.write(HELLO_MSG) 789 await on_eof 790 791 tr.close() 792 793 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 794 self.loop.run_until_complete( 795 asyncio.wait_for(client(srv.addr), 796 timeout=support.SHORT_TIMEOUT)) 797 798 # No garbage is left for SSL client from loop.create_connection, even 799 # if user stores the SSLTransport in corresponding protocol instance 800 client_context = weakref.ref(client_context) 801 self.assertIsNone(client_context()) 802 803 def test_start_tls_client_buf_proto_1(self): 804 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 805 806 server_context = test_utils.simple_server_sslcontext() 807 client_context = test_utils.simple_client_sslcontext() 808 809 client_con_made_calls = 0 810 811 def serve(sock): 812 sock.settimeout(self.TIMEOUT) 813 814 data = sock.recv_all(len(HELLO_MSG)) 815 self.assertEqual(len(data), len(HELLO_MSG)) 816 817 sock.starttls(server_context, server_side=True) 818 819 sock.sendall(b'O') 820 data = sock.recv_all(len(HELLO_MSG)) 821 self.assertEqual(len(data), len(HELLO_MSG)) 822 823 sock.sendall(b'2') 824 data = sock.recv_all(len(HELLO_MSG)) 825 self.assertEqual(len(data), len(HELLO_MSG)) 826 827 sock.unwrap() 828 sock.close() 829 830 class ClientProtoFirst(asyncio.BufferedProtocol): 831 def __init__(self, on_data): 832 self.on_data = on_data 833 self.buf = bytearray(1) 834 835 def connection_made(self, tr): 836 nonlocal client_con_made_calls 837 client_con_made_calls += 1 838 839 def get_buffer(self, sizehint): 840 return self.buf 841 842 def buffer_updated(self, nsize): 843 assert nsize == 1 844 self.on_data.set_result(bytes(self.buf[:nsize])) 845 846 def eof_received(self): 847 pass 848 849 class ClientProtoSecond(asyncio.Protocol): 850 def __init__(self, on_data, on_eof): 851 self.on_data = on_data 852 self.on_eof = on_eof 853 self.con_made_cnt = 0 854 855 def connection_made(self, tr): 856 nonlocal client_con_made_calls 857 client_con_made_calls += 1 858 859 def data_received(self, data): 860 self.on_data.set_result(data) 861 862 def eof_received(self): 863 self.on_eof.set_result(True) 864 865 async def client(addr): 866 await asyncio.sleep(0.5) 867 868 on_data1 = self.loop.create_future() 869 on_data2 = self.loop.create_future() 870 on_eof = self.loop.create_future() 871 872 tr, proto = await self.loop.create_connection( 873 lambda: ClientProtoFirst(on_data1), *addr) 874 875 tr.write(HELLO_MSG) 876 new_tr = await self.loop.start_tls(tr, proto, client_context) 877 878 self.assertEqual(await on_data1, b'O') 879 new_tr.write(HELLO_MSG) 880 881 new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof)) 882 self.assertEqual(await on_data2, b'2') 883 new_tr.write(HELLO_MSG) 884 await on_eof 885 886 new_tr.close() 887 888 # connection_made() should be called only once -- when 889 # we establish connection for the first time. Start TLS 890 # doesn't call connection_made() on application protocols. 891 self.assertEqual(client_con_made_calls, 1) 892 893 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 894 self.loop.run_until_complete( 895 asyncio.wait_for(client(srv.addr), 896 timeout=self.TIMEOUT)) 897 898 def test_start_tls_slow_client_cancel(self): 899 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 900 901 client_context = test_utils.simple_client_sslcontext() 902 server_waits_on_handshake = self.loop.create_future() 903 904 def serve(sock): 905 sock.settimeout(self.TIMEOUT) 906 907 data = sock.recv_all(len(HELLO_MSG)) 908 self.assertEqual(len(data), len(HELLO_MSG)) 909 910 try: 911 self.loop.call_soon_threadsafe( 912 server_waits_on_handshake.set_result, None) 913 data = sock.recv_all(1024 * 1024) 914 except ConnectionAbortedError: 915 pass 916 finally: 917 sock.close() 918 919 class ClientProto(asyncio.Protocol): 920 def __init__(self, on_data, on_eof): 921 self.on_data = on_data 922 self.on_eof = on_eof 923 self.con_made_cnt = 0 924 925 def connection_made(proto, tr): 926 proto.con_made_cnt += 1 927 # Ensure connection_made gets called only once. 928 self.assertEqual(proto.con_made_cnt, 1) 929 930 def data_received(self, data): 931 self.on_data.set_result(data) 932 933 def eof_received(self): 934 self.on_eof.set_result(True) 935 936 async def client(addr): 937 await asyncio.sleep(0.5) 938 939 on_data = self.loop.create_future() 940 on_eof = self.loop.create_future() 941 942 tr, proto = await self.loop.create_connection( 943 lambda: ClientProto(on_data, on_eof), *addr) 944 945 tr.write(HELLO_MSG) 946 947 await server_waits_on_handshake 948 949 with self.assertRaises(asyncio.TimeoutError): 950 await asyncio.wait_for( 951 self.loop.start_tls(tr, proto, client_context), 952 0.5) 953 954 with self.tcp_server(serve, timeout=self.TIMEOUT) as srv: 955 self.loop.run_until_complete( 956 asyncio.wait_for(client(srv.addr), 957 timeout=support.SHORT_TIMEOUT)) 958 959 def test_start_tls_server_1(self): 960 HELLO_MSG = b'1' * self.PAYLOAD_SIZE 961 962 server_context = test_utils.simple_server_sslcontext() 963 client_context = test_utils.simple_client_sslcontext() 964 965 def client(sock, addr): 966 sock.settimeout(self.TIMEOUT) 967 968 sock.connect(addr) 969 data = sock.recv_all(len(HELLO_MSG)) 970 self.assertEqual(len(data), len(HELLO_MSG)) 971 972 sock.starttls(client_context) 973 sock.sendall(HELLO_MSG) 974 975 sock.unwrap() 976 sock.close() 977 978 class ServerProto(asyncio.Protocol): 979 def __init__(self, on_con, on_eof, on_con_lost): 980 self.on_con = on_con 981 self.on_eof = on_eof 982 self.on_con_lost = on_con_lost 983 self.data = b'' 984 985 def connection_made(self, tr): 986 self.on_con.set_result(tr) 987 988 def data_received(self, data): 989 self.data += data 990 991 def eof_received(self): 992 self.on_eof.set_result(1) 993 994 def connection_lost(self, exc): 995 if exc is None: 996 self.on_con_lost.set_result(None) 997 else: 998 self.on_con_lost.set_exception(exc) 999 1000 async def main(proto, on_con, on_eof, on_con_lost): 1001 tr = await on_con 1002 tr.write(HELLO_MSG) 1003 1004 self.assertEqual(proto.data, b'') 1005 1006 new_tr = await self.loop.start_tls( 1007 tr, proto, server_context, 1008 server_side=True, 1009 ssl_handshake_timeout=self.TIMEOUT) 1010 1011 await on_eof 1012 await on_con_lost 1013 self.assertEqual(proto.data, HELLO_MSG) 1014 new_tr.close() 1015 1016 async def run_main(): 1017 on_con = self.loop.create_future() 1018 on_eof = self.loop.create_future() 1019 on_con_lost = self.loop.create_future() 1020 proto = ServerProto(on_con, on_eof, on_con_lost) 1021 1022 server = await self.loop.create_server( 1023 lambda: proto, '127.0.0.1', 0) 1024 addr = server.sockets[0].getsockname() 1025 1026 with self.tcp_client(lambda sock: client(sock, addr), 1027 timeout=self.TIMEOUT): 1028 await asyncio.wait_for( 1029 main(proto, on_con, on_eof, on_con_lost), 1030 timeout=self.TIMEOUT) 1031 1032 server.close() 1033 await server.wait_closed() 1034 1035 self.loop.run_until_complete(run_main()) 1036 1037 def test_create_server_ssl_over_ssl(self): 1038 CNT = 0 # number of clients that were successful 1039 TOTAL_CNT = 25 # total number of clients that test will create 1040 TIMEOUT = support.LONG_TIMEOUT # timeout for this test 1041 1042 A_DATA = b'A' * 1024 * BUF_MULTIPLIER 1043 B_DATA = b'B' * 1024 * BUF_MULTIPLIER 1044 1045 sslctx_1 = self._create_server_ssl_context( 1046 test_utils.ONLYCERT, test_utils.ONLYKEY) 1047 client_sslctx_1 = self._create_client_ssl_context() 1048 sslctx_2 = self._create_server_ssl_context( 1049 test_utils.ONLYCERT, test_utils.ONLYKEY) 1050 client_sslctx_2 = self._create_client_ssl_context() 1051 1052 clients = [] 1053 1054 async def handle_client(reader, writer): 1055 nonlocal CNT 1056 1057 data = await reader.readexactly(len(A_DATA)) 1058 self.assertEqual(data, A_DATA) 1059 writer.write(b'OK') 1060 1061 data = await reader.readexactly(len(B_DATA)) 1062 self.assertEqual(data, B_DATA) 1063 writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')]) 1064 1065 await writer.drain() 1066 writer.close() 1067 1068 CNT += 1 1069 1070 class ServerProtocol(asyncio.StreamReaderProtocol): 1071 def connection_made(self, transport): 1072 super_ = super() 1073 transport.pause_reading() 1074 fut = self._loop.create_task(self._loop.start_tls( 1075 transport, self, sslctx_2, server_side=True)) 1076 1077 def cb(_): 1078 try: 1079 tr = fut.result() 1080 except Exception as ex: 1081 super_.connection_lost(ex) 1082 else: 1083 super_.connection_made(tr) 1084 fut.add_done_callback(cb) 1085 1086 def server_protocol_factory(): 1087 reader = asyncio.StreamReader() 1088 protocol = ServerProtocol(reader, handle_client) 1089 return protocol 1090 1091 async def test_client(addr): 1092 fut = asyncio.Future() 1093 1094 def prog(sock): 1095 try: 1096 sock.connect(addr) 1097 sock.starttls(client_sslctx_1) 1098 1099 # because wrap_socket() doesn't work correctly on 1100 # SSLSocket, we have to do the 2nd level SSL manually 1101 incoming = ssl.MemoryBIO() 1102 outgoing = ssl.MemoryBIO() 1103 sslobj = client_sslctx_2.wrap_bio(incoming, outgoing) 1104 1105 def do(func, *args): 1106 while True: 1107 try: 1108 rv = func(*args) 1109 break 1110 except ssl.SSLWantReadError: 1111 if outgoing.pending: 1112 sock.send(outgoing.read()) 1113 incoming.write(sock.recv(65536)) 1114 if outgoing.pending: 1115 sock.send(outgoing.read()) 1116 return rv 1117 1118 do(sslobj.do_handshake) 1119 1120 do(sslobj.write, A_DATA) 1121 data = do(sslobj.read, 2) 1122 self.assertEqual(data, b'OK') 1123 1124 do(sslobj.write, B_DATA) 1125 data = b'' 1126 while True: 1127 chunk = do(sslobj.read, 4) 1128 if not chunk: 1129 break 1130 data += chunk 1131 self.assertEqual(data, b'SPAM') 1132 1133 do(sslobj.unwrap) 1134 sock.close() 1135 1136 except Exception as ex: 1137 self.loop.call_soon_threadsafe(fut.set_exception, ex) 1138 sock.close() 1139 else: 1140 self.loop.call_soon_threadsafe(fut.set_result, None) 1141 1142 client = self.tcp_client(prog) 1143 client.start() 1144 clients.append(client) 1145 1146 await fut 1147 1148 async def start_server(): 1149 extras = {} 1150 1151 srv = await self.loop.create_server( 1152 server_protocol_factory, 1153 '127.0.0.1', 0, 1154 family=socket.AF_INET, 1155 ssl=sslctx_1, 1156 **extras) 1157 1158 try: 1159 srv_socks = srv.sockets 1160 self.assertTrue(srv_socks) 1161 1162 addr = srv_socks[0].getsockname() 1163 1164 tasks = [] 1165 for _ in range(TOTAL_CNT): 1166 tasks.append(test_client(addr)) 1167 1168 await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT) 1169 1170 finally: 1171 self.loop.call_soon(srv.close) 1172 await srv.wait_closed() 1173 1174 with self._silence_eof_received_warning(): 1175 self.loop.run_until_complete(start_server()) 1176 1177 self.assertEqual(CNT, TOTAL_CNT) 1178 1179 for client in clients: 1180 client.stop() 1181 1182 def test_shutdown_cleanly(self): 1183 CNT = 0 1184 TOTAL_CNT = 25 1185 1186 A_DATA = b'A' * 1024 * BUF_MULTIPLIER 1187 1188 sslctx = self._create_server_ssl_context( 1189 test_utils.ONLYCERT, test_utils.ONLYKEY) 1190 client_sslctx = self._create_client_ssl_context() 1191 1192 def server(sock): 1193 sock.starttls( 1194 sslctx, 1195 server_side=True) 1196 1197 data = sock.recv_all(len(A_DATA)) 1198 self.assertEqual(data, A_DATA) 1199 sock.send(b'OK') 1200 1201 sock.unwrap() 1202 1203 sock.close() 1204 1205 async def client(addr): 1206 extras = {} 1207 extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT) 1208 1209 reader, writer = await asyncio.open_connection( 1210 *addr, 1211 ssl=client_sslctx, 1212 server_hostname='', 1213 **extras) 1214 1215 writer.write(A_DATA) 1216 self.assertEqual(await reader.readexactly(2), b'OK') 1217 1218 self.assertEqual(await reader.read(), b'') 1219 1220 nonlocal CNT 1221 CNT += 1 1222 1223 writer.close() 1224 await self.wait_closed(writer) 1225 1226 def run(coro): 1227 nonlocal CNT 1228 CNT = 0 1229 1230 async def _gather(*tasks): 1231 return await asyncio.gather(*tasks) 1232 1233 with self.tcp_server(server, 1234 max_clients=TOTAL_CNT, 1235 backlog=TOTAL_CNT) as srv: 1236 tasks = [] 1237 for _ in range(TOTAL_CNT): 1238 tasks.append(coro(srv.addr)) 1239 1240 self.loop.run_until_complete( 1241 _gather(*tasks)) 1242 1243 self.assertEqual(CNT, TOTAL_CNT) 1244 1245 with self._silence_eof_received_warning(): 1246 run(client) 1247 1248 def test_flush_before_shutdown(self): 1249 CHUNK = 1024 * 128 1250 SIZE = 32 1251 1252 sslctx = self._create_server_ssl_context( 1253 test_utils.ONLYCERT, test_utils.ONLYKEY) 1254 client_sslctx = self._create_client_ssl_context() 1255 1256 future = None 1257 1258 def server(sock): 1259 sock.starttls(sslctx, server_side=True) 1260 self.assertEqual(sock.recv_all(4), b'ping') 1261 sock.send(b'pong') 1262 time.sleep(0.5) # hopefully stuck the TCP buffer 1263 data = sock.recv_all(CHUNK * SIZE) 1264 self.assertEqual(len(data), CHUNK * SIZE) 1265 sock.close() 1266 1267 def run(meth): 1268 def wrapper(sock): 1269 try: 1270 meth(sock) 1271 except Exception as ex: 1272 self.loop.call_soon_threadsafe(future.set_exception, ex) 1273 else: 1274 self.loop.call_soon_threadsafe(future.set_result, None) 1275 return wrapper 1276 1277 async def client(addr): 1278 nonlocal future 1279 future = self.loop.create_future() 1280 reader, writer = await asyncio.open_connection( 1281 *addr, 1282 ssl=client_sslctx, 1283 server_hostname='') 1284 sslprotocol = writer.transport._ssl_protocol 1285 writer.write(b'ping') 1286 data = await reader.readexactly(4) 1287 self.assertEqual(data, b'pong') 1288 1289 sslprotocol.pause_writing() 1290 for _ in range(SIZE): 1291 writer.write(b'x' * CHUNK) 1292 1293 writer.close() 1294 sslprotocol.resume_writing() 1295 1296 await self.wait_closed(writer) 1297 try: 1298 data = await reader.read() 1299 self.assertEqual(data, b'') 1300 except ConnectionResetError: 1301 pass 1302 await future 1303 1304 with self.tcp_server(run(server)) as srv: 1305 self.loop.run_until_complete(client(srv.addr)) 1306 1307 def test_remote_shutdown_receives_trailing_data(self): 1308 CHUNK = 1024 * 128 1309 SIZE = 32 1310 1311 sslctx = self._create_server_ssl_context( 1312 test_utils.ONLYCERT, 1313 test_utils.ONLYKEY 1314 ) 1315 client_sslctx = self._create_client_ssl_context() 1316 future = None 1317 1318 def server(sock): 1319 incoming = ssl.MemoryBIO() 1320 outgoing = ssl.MemoryBIO() 1321 sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True) 1322 1323 while True: 1324 try: 1325 sslobj.do_handshake() 1326 except ssl.SSLWantReadError: 1327 if outgoing.pending: 1328 sock.send(outgoing.read()) 1329 incoming.write(sock.recv(16384)) 1330 else: 1331 if outgoing.pending: 1332 sock.send(outgoing.read()) 1333 break 1334 1335 while True: 1336 try: 1337 data = sslobj.read(4) 1338 except ssl.SSLWantReadError: 1339 incoming.write(sock.recv(16384)) 1340 else: 1341 break 1342 1343 self.assertEqual(data, b'ping') 1344 sslobj.write(b'pong') 1345 sock.send(outgoing.read()) 1346 1347 time.sleep(0.2) # wait for the peer to fill its backlog 1348 1349 # send close_notify but don't wait for response 1350 with self.assertRaises(ssl.SSLWantReadError): 1351 sslobj.unwrap() 1352 sock.send(outgoing.read()) 1353 1354 # should receive all data 1355 data_len = 0 1356 while True: 1357 try: 1358 chunk = len(sslobj.read(16384)) 1359 data_len += chunk 1360 except ssl.SSLWantReadError: 1361 incoming.write(sock.recv(16384)) 1362 except ssl.SSLZeroReturnError: 1363 break 1364 1365 self.assertEqual(data_len, CHUNK * SIZE) 1366 1367 # verify that close_notify is received 1368 sslobj.unwrap() 1369 1370 sock.close() 1371 1372 def eof_server(sock): 1373 sock.starttls(sslctx, server_side=True) 1374 self.assertEqual(sock.recv_all(4), b'ping') 1375 sock.send(b'pong') 1376 1377 time.sleep(0.2) # wait for the peer to fill its backlog 1378 1379 # send EOF 1380 sock.shutdown(socket.SHUT_WR) 1381 1382 # should receive all data 1383 data = sock.recv_all(CHUNK * SIZE) 1384 self.assertEqual(len(data), CHUNK * SIZE) 1385 1386 sock.close() 1387 1388 async def client(addr): 1389 nonlocal future 1390 future = self.loop.create_future() 1391 1392 reader, writer = await asyncio.open_connection( 1393 *addr, 1394 ssl=client_sslctx, 1395 server_hostname='') 1396 writer.write(b'ping') 1397 data = await reader.readexactly(4) 1398 self.assertEqual(data, b'pong') 1399 1400 # fill write backlog in a hacky way - renegotiation won't help 1401 for _ in range(SIZE): 1402 writer.transport._test__append_write_backlog(b'x' * CHUNK) 1403 1404 try: 1405 data = await reader.read() 1406 self.assertEqual(data, b'') 1407 except (BrokenPipeError, ConnectionResetError): 1408 pass 1409 1410 await future 1411 1412 writer.close() 1413 await self.wait_closed(writer) 1414 1415 def run(meth): 1416 def wrapper(sock): 1417 try: 1418 meth(sock) 1419 except Exception as ex: 1420 self.loop.call_soon_threadsafe(future.set_exception, ex) 1421 else: 1422 self.loop.call_soon_threadsafe(future.set_result, None) 1423 return wrapper 1424 1425 with self.tcp_server(run(server)) as srv: 1426 self.loop.run_until_complete(client(srv.addr)) 1427 1428 with self.tcp_server(run(eof_server)) as srv: 1429 self.loop.run_until_complete(client(srv.addr)) 1430 1431 def test_connect_timeout_warning(self): 1432 s = socket.socket(socket.AF_INET) 1433 s.bind(('127.0.0.1', 0)) 1434 addr = s.getsockname() 1435 1436 async def test(): 1437 try: 1438 await asyncio.wait_for( 1439 self.loop.create_connection(asyncio.Protocol, 1440 *addr, ssl=True), 1441 0.1) 1442 except (ConnectionRefusedError, asyncio.TimeoutError): 1443 pass 1444 else: 1445 self.fail('TimeoutError is not raised') 1446 1447 with s: 1448 try: 1449 with self.assertWarns(ResourceWarning) as cm: 1450 self.loop.run_until_complete(test()) 1451 gc.collect() 1452 gc.collect() 1453 gc.collect() 1454 except AssertionError as e: 1455 self.assertEqual(str(e), 'ResourceWarning not triggered') 1456 else: 1457 self.fail('Unexpected ResourceWarning: {}'.format(cm.warning)) 1458 1459 def test_handshake_timeout_handler_leak(self): 1460 s = socket.socket(socket.AF_INET) 1461 s.bind(('127.0.0.1', 0)) 1462 s.listen(1) 1463 addr = s.getsockname() 1464 1465 async def test(ctx): 1466 try: 1467 await asyncio.wait_for( 1468 self.loop.create_connection(asyncio.Protocol, *addr, 1469 ssl=ctx), 1470 0.1) 1471 except (ConnectionRefusedError, asyncio.TimeoutError): 1472 pass 1473 else: 1474 self.fail('TimeoutError is not raised') 1475 1476 with s: 1477 ctx = ssl.create_default_context() 1478 self.loop.run_until_complete(test(ctx)) 1479 ctx = weakref.ref(ctx) 1480 1481 # SSLProtocol should be DECREF to 0 1482 self.assertIsNone(ctx()) 1483 1484 def test_shutdown_timeout_handler_leak(self): 1485 loop = self.loop 1486 1487 def server(sock): 1488 sslctx = self._create_server_ssl_context( 1489 test_utils.ONLYCERT, 1490 test_utils.ONLYKEY 1491 ) 1492 sock = sslctx.wrap_socket(sock, server_side=True) 1493 sock.recv(32) 1494 sock.close() 1495 1496 class Protocol(asyncio.Protocol): 1497 def __init__(self): 1498 self.fut = asyncio.Future(loop=loop) 1499 1500 def connection_lost(self, exc): 1501 self.fut.set_result(None) 1502 1503 async def client(addr, ctx): 1504 tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) 1505 tr.close() 1506 await pr.fut 1507 1508 with self.tcp_server(server) as srv: 1509 ctx = self._create_client_ssl_context() 1510 loop.run_until_complete(client(srv.addr, ctx)) 1511 ctx = weakref.ref(ctx) 1512 1513 # asyncio has no shutdown timeout, but it ends up with a circular 1514 # reference loop - not ideal (introduces gc glitches), but at least 1515 # not leaking 1516 gc.collect() 1517 gc.collect() 1518 gc.collect() 1519 1520 # SSLProtocol should be DECREF to 0 1521 self.assertIsNone(ctx()) 1522 1523 def test_shutdown_timeout_handler_not_set(self): 1524 loop = self.loop 1525 eof = asyncio.Event() 1526 extra = None 1527 1528 def server(sock): 1529 sslctx = self._create_server_ssl_context( 1530 test_utils.ONLYCERT, 1531 test_utils.ONLYKEY 1532 ) 1533 sock = sslctx.wrap_socket(sock, server_side=True) 1534 sock.send(b'hello') 1535 assert sock.recv(1024) == b'world' 1536 sock.send(b'extra bytes') 1537 # sending EOF here 1538 sock.shutdown(socket.SHUT_WR) 1539 loop.call_soon_threadsafe(eof.set) 1540 # make sure we have enough time to reproduce the issue 1541 assert sock.recv(1024) == b'' 1542 sock.close() 1543 1544 class Protocol(asyncio.Protocol): 1545 def __init__(self): 1546 self.fut = asyncio.Future(loop=loop) 1547 self.transport = None 1548 1549 def connection_made(self, transport): 1550 self.transport = transport 1551 1552 def data_received(self, data): 1553 if data == b'hello': 1554 self.transport.write(b'world') 1555 # pause reading would make incoming data stay in the sslobj 1556 self.transport.pause_reading() 1557 else: 1558 nonlocal extra 1559 extra = data 1560 1561 def connection_lost(self, exc): 1562 if exc is None: 1563 self.fut.set_result(None) 1564 else: 1565 self.fut.set_exception(exc) 1566 1567 async def client(addr): 1568 ctx = self._create_client_ssl_context() 1569 tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) 1570 await eof.wait() 1571 tr.resume_reading() 1572 await pr.fut 1573 tr.close() 1574 assert extra == b'extra bytes' 1575 1576 with self.tcp_server(server) as srv: 1577 loop.run_until_complete(client(srv.addr)) 1578 1579 1580############################################################################### 1581# Socket Testing Utilities 1582############################################################################### 1583 1584 1585class TestSocketWrapper: 1586 1587 def __init__(self, sock): 1588 self.__sock = sock 1589 1590 def recv_all(self, n): 1591 buf = b'' 1592 while len(buf) < n: 1593 data = self.recv(n - len(buf)) 1594 if data == b'': 1595 raise ConnectionAbortedError 1596 buf += data 1597 return buf 1598 1599 def starttls(self, ssl_context, *, 1600 server_side=False, 1601 server_hostname=None, 1602 do_handshake_on_connect=True): 1603 1604 assert isinstance(ssl_context, ssl.SSLContext) 1605 1606 ssl_sock = ssl_context.wrap_socket( 1607 self.__sock, server_side=server_side, 1608 server_hostname=server_hostname, 1609 do_handshake_on_connect=do_handshake_on_connect) 1610 1611 if server_side: 1612 ssl_sock.do_handshake() 1613 1614 self.__sock.close() 1615 self.__sock = ssl_sock 1616 1617 def __getattr__(self, name): 1618 return getattr(self.__sock, name) 1619 1620 def __repr__(self): 1621 return '<{} {!r}>'.format(type(self).__name__, self.__sock) 1622 1623 1624class SocketThread(threading.Thread): 1625 1626 def stop(self): 1627 self._active = False 1628 self.join() 1629 1630 def __enter__(self): 1631 self.start() 1632 return self 1633 1634 def __exit__(self, *exc): 1635 self.stop() 1636 1637 1638class TestThreadedClient(SocketThread): 1639 1640 def __init__(self, test, sock, prog, timeout): 1641 threading.Thread.__init__(self, None, None, 'test-client') 1642 self.daemon = True 1643 1644 self._timeout = timeout 1645 self._sock = sock 1646 self._active = True 1647 self._prog = prog 1648 self._test = test 1649 1650 def run(self): 1651 try: 1652 self._prog(TestSocketWrapper(self._sock)) 1653 except (KeyboardInterrupt, SystemExit): 1654 raise 1655 except BaseException as ex: 1656 self._test._abort_socket_test(ex) 1657 1658 1659class TestThreadedServer(SocketThread): 1660 1661 def __init__(self, test, sock, prog, timeout, max_clients): 1662 threading.Thread.__init__(self, None, None, 'test-server') 1663 self.daemon = True 1664 1665 self._clients = 0 1666 self._finished_clients = 0 1667 self._max_clients = max_clients 1668 self._timeout = timeout 1669 self._sock = sock 1670 self._active = True 1671 1672 self._prog = prog 1673 1674 self._s1, self._s2 = socket.socketpair() 1675 self._s1.setblocking(False) 1676 1677 self._test = test 1678 1679 def stop(self): 1680 try: 1681 if self._s2 and self._s2.fileno() != -1: 1682 try: 1683 self._s2.send(b'stop') 1684 except OSError: 1685 pass 1686 finally: 1687 super().stop() 1688 1689 def run(self): 1690 try: 1691 with self._sock: 1692 self._sock.setblocking(False) 1693 self._run() 1694 finally: 1695 self._s1.close() 1696 self._s2.close() 1697 1698 def _run(self): 1699 while self._active: 1700 if self._clients >= self._max_clients: 1701 return 1702 1703 r, w, x = select.select( 1704 [self._sock, self._s1], [], [], self._timeout) 1705 1706 if self._s1 in r: 1707 return 1708 1709 if self._sock in r: 1710 try: 1711 conn, addr = self._sock.accept() 1712 except BlockingIOError: 1713 continue 1714 except socket.timeout: 1715 if not self._active: 1716 return 1717 else: 1718 raise 1719 else: 1720 self._clients += 1 1721 conn.settimeout(self._timeout) 1722 try: 1723 with conn: 1724 self._handle_client(conn) 1725 except (KeyboardInterrupt, SystemExit): 1726 raise 1727 except BaseException as ex: 1728 self._active = False 1729 try: 1730 raise 1731 finally: 1732 self._test._abort_socket_test(ex) 1733 1734 def _handle_client(self, sock): 1735 self._prog(TestSocketWrapper(sock)) 1736 1737 @property 1738 def addr(self): 1739 return self._sock.getsockname() 1740