"""Tests for asyncio/sslproto.py."""

import logging
import socket
import unittest
import weakref
from test import support
from unittest import mock
try:
    import ssl
except ImportError:
    ssl = None

import asyncio
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests


def tearDownModule():
    # Reset the global event loop policy once this module's tests are done.
    asyncio.set_event_loop_policy(None)


@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Tests that drive sslproto.SSLProtocol with mocked transport/SSL objects."""

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        """Return an SSLProtocol wrapping *proto* with a 0.1s handshake timeout.

        A plain asyncio.Protocol is used as the application protocol when
        *proto* is None.  The application transport is closed on cleanup.
        """
        sslcontext = test_utils.dummy_ssl_context()
        if proto is None:  # app protocol
            proto = asyncio.Protocol()
        ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
                                         ssl_handshake_timeout=0.1)
        self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
        self.addCleanup(ssl_proto._app_transport.close)
        return ssl_proto

    def connection_made(self, ssl_proto, *, do_handshake=None):
        """Call ssl_proto.connection_made() with mocks; return the transport.

        The SSL object is replaced by a mock whose read() raises
        ssl.SSLWantReadError.  If *do_handshake* is given, it is installed as
        the mock's do_handshake attribute.
        """
        transport = mock.Mock()
        sslobj = mock.Mock()
        # emulate reading decompressed data
        sslobj.read.side_effect = ssl.SSLWantReadError
        if do_handshake is not None:
            sslobj.do_handshake = do_handshake
        ssl_proto._sslobj = sslobj
        ssl_proto.connection_made(transport)
        return transport

    def test_handshake_timeout_zero(self):
        # A zero handshake timeout must be rejected by the constructor.
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=0)

    def test_handshake_timeout_negative(self):
        # A negative handshake timeout must be rejected by the constructor.
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=-10)

    def test_eof_received_waiter(self):
        # EOF while the handshake is still pending must fail the waiter
        # with ConnectionResetError.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        ssl_proto.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363.
        # _fatal_error() generates a NameError if sslproto.py
        # does not import base_events.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        # Temporarily turn off error logging so as not to spoil test output.
        log_level = log.logger.getEffectiveLevel()
        log.logger.setLevel(logging.FATAL)
        try:
            ssl_proto._fatal_error(None)
        finally:
            # Restore error logging.
            log.logger.setLevel(log_level)

    def test_connection_lost(self):
        # From issue #472.
        # yield from waiter hang if lost_connection was called.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        ssl_proto.connection_lost(ConnectionAbortedError)
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743 Closing transport during handshake process leaks socket
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)

        transport = self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        test_utils.run_briefly(self.loop)

        ssl_proto._app_transport.close()
        self.assertTrue(transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        # _get_extra_info() returns the default (None unless given) both
        # before connection_made() and after connection_lost().
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))
        default = object()
        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
        self.connection_made(ssl_proto)
        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
        ssl_proto.connection_lost(None)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        # set_protocol() on the app transport must also update the
        # protocol stored on the SSLProtocol itself.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        new_app_proto = asyncio.Protocol()
        ssl_proto._app_transport.set_protocol(new_app_proto)
        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
        self.assertIs(ssl_proto._app_protocol, new_app_proto)

    def test_data_received_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport

        transp.close()

        # should not raise
        self.assertIsNone(ssl_proto.buffer_updated(5))

    def test_write_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport
        transp.close()

        # should not raise
        self.assertIsNone(transp.write(b'data'))


##############################################################################
# Start TLS Tests
##############################################################################


class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Start TLS / SSL connection tests shared by the loop-specific classes.

    Concrete subclasses must override new_loop() to return the event loop
    implementation under test.
    """

    PAYLOAD_SIZE = 1024 * 100
    TIMEOUT = support.LONG_TIMEOUT

    def new_loop(self):
        raise NotImplementedError

    def test_buf_feed_data(self):
        # _feed_data_to_buffered_proto() must deliver all the data whatever
        # the size of the buffer returned by get_buffer(), for both
        # bytearray and memoryview buffers, and reject an empty buffer.

        class Proto(asyncio.BufferedProtocol):

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        # Client side: exchange plain data, upgrade the connection with
        # loop.start_tls(), then exchange data over TLS using a regular
        # (non-buffered) protocol.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named "proto" here so that "self" keeps
            # referring to the enclosing test case.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named "proto" here so that "self" keeps
            # referring to the enclosing test case.
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_start_tls_client_buf_proto_1(self):
        # Same scenario as the regular-protocol client test, but starting
        # with a BufferedProtocol and switching to a regular Protocol via
        # set_protocol() after the TLS upgrade.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            # The instance is named "slf" here so that "self" keeps
            # referring to the enclosing test case.
            def buffer_updated(slf, nsize):
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        # The server never performs its side of the handshake, so the
        # client's start_tls() wrapped in wait_for(..., 0.5) must time out.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # serve() is not called from the event loop (hence the
                # thread-safe scheduling of the future's result).
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                # Only the blocking read matters; the result is not used.
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named "proto" here so that "self" keeps
            # referring to the enclosing test case.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

    def test_start_tls_server_1(self):
        # Server side: send plain data, upgrade the accepted connection
        # with loop.start_tls(server_side=True), then exchange data over TLS.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_con_lost, on_got_hello),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        # start_tls() must raise TypeError for a non-SSLContext context and
        # for an unsupported transport.
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The pending handshake timeout (support.SHORT_TIMEOUT) should be
        # cancelled to free related objects without really waiting for it
        # to expire
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        # The server never answers the handshake, so open_connection() must
        # fail with ConnectionAbortedError once the 1.0s handshake timeout
        # expires, without anything reaching the loop's exception handler.
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        # With certificate verification enabled on the client, the
        # connection attempt must raise ssl.SSLCertVerificationError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except OSError:
                # Best effort: the handshake is expected to fail on this
                # side.  ssl.SSLError is a subclass of OSError, so this
                # single handler covers both.
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        # The server writes raw bytes on a duplicate of the socket, bypassing
        # TLS; the client's next read must raise ssl.SSLError.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Duplicate the socket before the TLS upgrade so that bytes can
            # later be sent on it directly.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')


@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS tests on asyncio.SelectorEventLoop."""

    def new_loop(self):
        return asyncio.SelectorEventLoop()


@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the BaseStartTLS tests on asyncio.ProactorEventLoop."""

    def new_loop(self):
        return asyncio.ProactorEventLoop()


if __name__ == '__main__':
    unittest.main()