1"""Tests for asyncio/sslproto.py."""
2
3import logging
4import socket
5import unittest
6import weakref
7from test import support
8from unittest import mock
9try:
10    import ssl
11except ImportError:
12    ssl = None
13
14import asyncio
15from asyncio import log
16from asyncio import protocols
17from asyncio import sslproto
18from test.test_asyncio import utils as test_utils
19from test.test_asyncio import functional as func_tests
20
21
def tearDownModule():
    """Restore the default event loop policy once this module's tests finish."""
    default_policy = None  # None tells asyncio to reinstall its default policy
    asyncio.set_event_loop_policy(default_policy)
24
25
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Unit tests for sslproto.SSLProtocol.

    The underlying transport and SSL object are replaced with mocks, so
    these tests exercise the protocol's state handling without real
    network I/O.
    """

    def setUp(self):
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, *, waiter=None, proto=None):
        """Create an SSLProtocol bound to self.loop and return it.

        proto is the application protocol (a fresh asyncio.Protocol when
        None); waiter is passed straight through to SSLProtocol.  A dummy
        SSL context and a 0.1 second handshake timeout are used.  The
        protocol's app transport is closed during test cleanup.
        """
        sslcontext = test_utils.dummy_ssl_context()
        if proto is None:  # app protocol
            proto = asyncio.Protocol()
        ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
                                         ssl_handshake_timeout=0.1)
        self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
        self.addCleanup(ssl_proto._app_transport.close)
        return ssl_proto

    def connection_made(self, ssl_proto, *, do_handshake=None):
        """Connect ssl_proto to a mock transport and return that transport.

        The protocol's SSL object is replaced by a mock whose read() raises
        ssl.SSLWantReadError.  When do_handshake is given, it replaces the
        mock's do_handshake.
        """
        transport = mock.Mock()
        sslobj = mock.Mock()
        # read() raising SSLWantReadError emulates "no decrypted application
        # data available yet".
        sslobj.read.side_effect = ssl.SSLWantReadError
        if do_handshake is not None:
            sslobj.do_handshake = do_handshake
        ssl_proto._sslobj = sslobj
        ssl_proto.connection_made(transport)
        return transport

    def test_handshake_timeout_zero(self):
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=0)

    def test_handshake_timeout_negative(self):
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=-10)

    def test_eof_received_waiter(self):
        # EOF arriving mid-handshake must fail the waiter with
        # ConnectionResetError.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        ssl_proto.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363.
        # _fatal_error() generates a NameError if sslproto.py
        # does not import base_events.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        # Temporarily turn off error logging so as not to spoil test output.
        # Save the logger's own level (possibly NOTSET), not its effective
        # level: restoring the effective level would pin an explicit level
        # on the logger and leak configuration into later tests.
        log_level = log.logger.level
        log.logger.setLevel(logging.FATAL)
        try:
            ssl_proto._fatal_error(None)
        finally:
            # Restore error logging.
            log.logger.setLevel(log_level)

    def test_connection_lost(self):
        # From issue #472.
        # Awaiting the waiter hung if connection_lost() was called.
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        # connection_lost() takes an exception instance (or None).
        ssl_proto.connection_lost(ConnectionAbortedError())
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743 Closing transport during handshake process leaks socket
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)

        transport = self.connection_made(
            ssl_proto,
            do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
        )
        test_utils.run_briefly(self.loop)

        ssl_proto._app_transport.close()
        self.assertTrue(transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))
        default = object()
        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
        self.connection_made(ssl_proto)
        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
        ssl_proto.connection_lost(None)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        waiter = self.loop.create_future()
        ssl_proto = self.ssl_protocol(waiter=waiter)
        new_app_proto = asyncio.Protocol()
        ssl_proto._app_transport.set_protocol(new_app_proto)
        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
        self.assertIs(ssl_proto._app_protocol, new_app_proto)

    def test_data_received_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport

        transp.close()

        # should not raise
        self.assertIsNone(ssl_proto.buffer_updated(5))

    def test_write_after_closing(self):
        ssl_proto = self.ssl_protocol()
        self.connection_made(ssl_proto)
        transp = ssl_proto._app_transport
        transp.close()

        # should not raise
        self.assertIsNone(transp.write(b'data'))
161
162
163##############################################################################
164# Start TLS Tests
165##############################################################################
166
167
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
    """Functional start-TLS / SSL connection tests shared by loop types.

    Concrete subclasses override new_loop() to supply the event loop
    implementation under test (see SelectorStartTLSTests and
    ProactorStartTLSTests).
    """

    PAYLOAD_SIZE = 1024 * 100
    TIMEOUT = support.LONG_TIMEOUT

    def new_loop(self):
        raise NotImplementedError

    def test_buf_feed_data(self):

        class Proto(asyncio.BufferedProtocol):

            def __init__(self, bufsize, usemv):
                self.buf = bytearray(bufsize)
                self.mv = memoryview(self.buf)
                self.data = b''
                self.usemv = usemv

            def get_buffer(self, sizehint):
                if self.usemv:
                    return self.mv
                else:
                    return self.buf

            def buffer_updated(self, nsize):
                if self.usemv:
                    self.data += self.mv[:nsize]
                else:
                    self.data += self.buf[:nsize]

        for usemv in [False, True]:
            proto = Proto(1, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(2, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(4, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'1234')
            self.assertEqual(proto.data, b'1234')

            proto = Proto(100, usemv)
            protocols._feed_data_to_buffered_proto(proto, b'12345')
            self.assertEqual(proto.data, b'12345')

            proto = Proto(0, usemv)
            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
                protocols._feed_data_to_buffered_proto(proto, b'12345')

    def test_start_tls_client_reg_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # 'proto' is the protocol instance; 'self' is the enclosing
            # TestCase, used for assertions.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left if SSL is closed uncleanly
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_create_connection_memory_leak(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # 'proto' is the protocol instance; 'self' is the enclosing
            # TestCase, used for assertions.
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        client_context = weakref.ref(client_context)
        support.gc_collect()
        self.assertIsNone(client_context())

    def test_start_tls_client_buf_proto_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.shutdown(socket.SHUT_RDWR)
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            # 'slf' is the protocol instance; 'self' is the enclosing
            # TestCase, used for assertions.
            def buffer_updated(slf, nsize):
                self.assertEqual(nsize, 1)
                slf.on_data.set_result(bytes(slf.buf[:nsize]))

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))

    def test_start_tls_slow_client_cancel(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                # The server thread signals the loop thread, then never
                # answers the client's TLS handshake.
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                data = sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # 'proto' is the protocol instance; 'self' is the enclosing
            # TestCase, used for assertions.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

    def test_start_tls_server_1(self):
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
        ANSWER = b'answer'

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()
        answer = None

        def client(sock, addr):
            nonlocal answer
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.start_tls(client_context)
            sock.sendall(HELLO_MSG)
            answer = sock.recv_all(len(ANSWER))
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_con_lost, on_got_hello):
                self.on_con = on_con
                self.on_con_lost = on_con_lost
                self.on_got_hello = on_got_hello
                self.data = b''
                self.transport = None

            def connection_made(self, tr):
                self.transport = tr
                self.on_con.set_result(tr)

            def replace_transport(self, tr):
                self.transport = tr

            def data_received(self, data):
                self.data += data
                if len(self.data) >= len(HELLO_MSG):
                    self.on_got_hello.set_result(None)

            def connection_lost(self, exc):
                self.transport = None
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_con_lost, on_got_hello):
            tr = await on_con
            tr.write(HELLO_MSG)

            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)
            proto.replace_transport(new_tr)

            await on_got_hello
            new_tr.write(ANSWER)

            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            on_got_hello = self.loop.create_future()
            proto = ServerProto(on_con, on_con_lost, on_got_hello)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            try:
                addr = server.sockets[0].getsockname()

                with self.tcp_client(lambda sock: client(sock, addr),
                                     timeout=self.TIMEOUT):
                    await asyncio.wait_for(
                        main(proto, on_con, on_con_lost, on_got_hello),
                        timeout=self.TIMEOUT)
            finally:
                # Stop listening even if the exchange above fails, so the
                # listening socket is not leaked.  wait_closed() is only
                # awaited on success: in recent Python versions it also
                # waits for active connections, which may still be open
                # after a failure.
                server.close()
            await server.wait_closed()
            self.assertEqual(answer, ANSWER)

        self.loop.run_until_complete(run_main())

    def test_start_tls_wrong_args(self):
        async def main():
            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
                await self.loop.start_tls(None, None, None)

            sslctx = test_utils.simple_server_sslcontext()
            with self.assertRaisesRegex(TypeError, 'is not supported'):
                await self.loop.start_tls(None, None, sslctx)

        self.loop.run_until_complete(main())

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=support.SHORT_TIMEOUT),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])

        # The handshake timeout (support.SHORT_TIMEOUT) should be cancelled
        # to free related objects without really waiting for it to expire.
        client_sslctx = weakref.ref(client_sslctx)
        support.gc_collect()
        self.assertIsNone(client_sslctx())

    def test_create_connection_ssl_slow_handshake(self):
        client_sslctx = test_utils.simple_client_sslcontext()

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        def server(sock):
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            # Expected to fail: the server never completes the handshake.
            await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(messages, [])

    def test_create_connection_ssl_failed_certificate(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext(
            disable_verify=False)

        def server(sock):
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
            except OSError:
                # ssl.SSLError is an OSError subclass, so this covers both
                # TLS-level failures and plain socket errors, e.g. when the
                # client aborts after rejecting the certificate.
                pass
            finally:
                sock.close()

        async def client(addr):
            # Expected to fail certificate verification.
            await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(ssl.SSLCertVerificationError):
                self.loop.run_until_complete(client(srv.addr))

    def test_start_tls_client_corrupted_ssl(self):
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # Keep a duplicate of the raw socket so plaintext garbage can be
            # written underneath the established TLS session.
            orig_sock = sock.dup()
            try:
                sock.start_tls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                orig_sock.close()
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()

            writer.close()
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
761
762
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared start-TLS tests on a selector-based event loop."""

    def new_loop(self):
        loop_factory = asyncio.SelectorEventLoop
        return loop_factory()
768
769
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
    """Run the shared start-TLS tests on the proactor event loop."""

    def new_loop(self):
        # Only reached when asyncio provides ProactorEventLoop (see the
        # skipUnless decorator above).
        loop_factory = asyncio.ProactorEventLoop
        return loop_factory()
776
777
if __name__ == '__main__':
    # Run this module's tests when executed directly as a script.
    unittest.main(module='__main__')
780