1import asyncio
2import asyncio.sslproto
3import contextlib
4import gc
5import logging
6import select
7import socket
8import sys
9import tempfile
10import threading
11import time
12import weakref
13import unittest
14
15try:
16    import ssl
17except ImportError:
18    ssl = None
19
20from test import support
21from test.test_asyncio import utils as test_utils
22
23
# Payload-size multiplier used by the bulk-transfer tests; reduced on macOS.
MACOS = sys.platform == 'darwin'
BUF_MULTIPLIER = 64 if MACOS else 1024
26
27
def tearDownModule():
    """Restore asyncio's default event loop policy once the module is done."""
    default_policy = None  # None means "use the default policy"
    asyncio.set_event_loop_policy(default_policy)
30
31
class MyBaseProto(asyncio.Protocol):
    """Protocol that records its lifecycle state and the bytes it received.

    State moves INITIAL -> CONNECTED -> (EOF) -> CLOSED; each callback asserts
    it is called in a legal state.  When *loop* is given, ``connected`` and
    ``done`` become futures resolved by connection_made() and
    connection_lost() respectively; otherwise they keep the class-level
    ``None`` defaults.
    """

    connected = None
    done = None

    def __init__(self, loop=None):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        if loop is None:
            return
        self.connected = asyncio.Future(loop=loop)
        self.done = asyncio.Future(loop=loop)

    def connection_made(self, transport):
        self.transport = transport
        current = self.state
        assert current == 'INITIAL', current
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def data_received(self, data):
        current = self.state
        assert current == 'CONNECTED', current
        self.nbytes = self.nbytes + len(data)

    def eof_received(self):
        current = self.state
        assert current == 'CONNECTED', current
        self.state = 'EOF'

    def connection_lost(self, exc):
        current = self.state
        assert current in {'CONNECTED', 'EOF'}, current
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)
64
65
class MessageOutFilter(logging.Filter):
    """Logging filter that drops records whose message contains *msg*.

    Only the record's raw ``msg`` attribute is inspected (not the formatted
    message).  Logging allows ``msg`` to be any object, so non-string values
    are compared through ``str()`` instead of raising TypeError.
    """

    def __init__(self, msg):
        # Initialise the base Filter so its attributes (name, nlen) exist.
        super().__init__()
        self.msg = msg

    def filter(self, record):
        """Return False (drop the record) if ``self.msg`` occurs in it."""
        raw = record.msg
        text = raw if isinstance(raw, str) else str(raw)
        return self.msg not in text
74
75
76@unittest.skipIf(ssl is None, 'No ssl module')
77class TestSSL(test_utils.TestCase):
78
79    PAYLOAD_SIZE = 1024 * 100
80    TIMEOUT = support.LONG_TIMEOUT
81
82    def setUp(self):
83        super().setUp()
84        self.loop = asyncio.new_event_loop()
85        self.set_event_loop(self.loop)
86        self.addCleanup(self.loop.close)
87
88    def tearDown(self):
89        # just in case if we have transport close callbacks
90        if not self.loop.is_closed():
91            test_utils.run_briefly(self.loop)
92
93        self.doCleanups()
94        support.gc_collect()
95        super().tearDown()
96
97    def tcp_server(self, server_prog, *,
98                   family=socket.AF_INET,
99                   addr=None,
100                   timeout=support.SHORT_TIMEOUT,
101                   backlog=1,
102                   max_clients=10):
103
104        if addr is None:
105            if family == getattr(socket, "AF_UNIX", None):
106                with tempfile.NamedTemporaryFile() as tmp:
107                    addr = tmp.name
108            else:
109                addr = ('127.0.0.1', 0)
110
111        sock = socket.socket(family, socket.SOCK_STREAM)
112
113        if timeout is None:
114            raise RuntimeError('timeout is required')
115        if timeout <= 0:
116            raise RuntimeError('only blocking sockets are supported')
117        sock.settimeout(timeout)
118
119        try:
120            sock.bind(addr)
121            sock.listen(backlog)
122        except OSError as ex:
123            sock.close()
124            raise ex
125
126        return TestThreadedServer(
127            self, sock, server_prog, timeout, max_clients)
128
129    def tcp_client(self, client_prog,
130                   family=socket.AF_INET,
131                   timeout=support.SHORT_TIMEOUT):
132
133        sock = socket.socket(family, socket.SOCK_STREAM)
134
135        if timeout is None:
136            raise RuntimeError('timeout is required')
137        if timeout <= 0:
138            raise RuntimeError('only blocking sockets are supported')
139        sock.settimeout(timeout)
140
141        return TestThreadedClient(
142            self, sock, client_prog, timeout)
143
144    def unix_server(self, *args, **kwargs):
145        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
146
147    def unix_client(self, *args, **kwargs):
148        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
149
150    def _create_server_ssl_context(self, certfile, keyfile=None):
151        sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
152        sslcontext.options |= ssl.OP_NO_SSLv2
153        sslcontext.load_cert_chain(certfile, keyfile)
154        return sslcontext
155
156    def _create_client_ssl_context(self, *, disable_verify=True):
157        sslcontext = ssl.create_default_context()
158        sslcontext.check_hostname = False
159        if disable_verify:
160            sslcontext.verify_mode = ssl.CERT_NONE
161        return sslcontext
162
163    @contextlib.contextmanager
164    def _silence_eof_received_warning(self):
165        # TODO This warning has to be fixed in asyncio.
166        logger = logging.getLogger('asyncio')
167        filter = MessageOutFilter('has no effect when using ssl')
168        logger.addFilter(filter)
169        try:
170            yield
171        finally:
172            logger.removeFilter(filter)
173
174    def _abort_socket_test(self, ex):
175        try:
176            self.loop.stop()
177        finally:
178            self.fail(ex)
179
180    def new_loop(self):
181        return asyncio.new_event_loop()
182
183    def new_policy(self):
184        return asyncio.DefaultEventLoopPolicy()
185
186    async def wait_closed(self, obj):
187        if not isinstance(obj, asyncio.StreamWriter):
188            return
189        try:
190            await obj.wait_closed()
191        except (BrokenPipeError, ConnectionError):
192            pass
193
194    def test_create_server_ssl_1(self):
195        CNT = 0           # number of clients that were successful
196        TOTAL_CNT = 25    # total number of clients that test will create
197        TIMEOUT = support.LONG_TIMEOUT  # timeout for this test
198
199        A_DATA = b'A' * 1024 * BUF_MULTIPLIER
200        B_DATA = b'B' * 1024 * BUF_MULTIPLIER
201
202        sslctx = self._create_server_ssl_context(
203            test_utils.ONLYCERT, test_utils.ONLYKEY
204        )
205        client_sslctx = self._create_client_ssl_context()
206
207        clients = []
208
209        async def handle_client(reader, writer):
210            nonlocal CNT
211
212            data = await reader.readexactly(len(A_DATA))
213            self.assertEqual(data, A_DATA)
214            writer.write(b'OK')
215
216            data = await reader.readexactly(len(B_DATA))
217            self.assertEqual(data, B_DATA)
218            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])
219
220            await writer.drain()
221            writer.close()
222
223            CNT += 1
224
225        async def test_client(addr):
226            fut = asyncio.Future()
227
228            def prog(sock):
229                try:
230                    sock.starttls(client_sslctx)
231                    sock.connect(addr)
232                    sock.send(A_DATA)
233
234                    data = sock.recv_all(2)
235                    self.assertEqual(data, b'OK')
236
237                    sock.send(B_DATA)
238                    data = sock.recv_all(4)
239                    self.assertEqual(data, b'SPAM')
240
241                    sock.close()
242
243                except Exception as ex:
244                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
245                else:
246                    self.loop.call_soon_threadsafe(fut.set_result, None)
247
248            client = self.tcp_client(prog)
249            client.start()
250            clients.append(client)
251
252            await fut
253
254        async def start_server():
255            extras = {}
256            extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
257
258            srv = await asyncio.start_server(
259                handle_client,
260                '127.0.0.1', 0,
261                family=socket.AF_INET,
262                ssl=sslctx,
263                **extras)
264
265            try:
266                srv_socks = srv.sockets
267                self.assertTrue(srv_socks)
268
269                addr = srv_socks[0].getsockname()
270
271                tasks = []
272                for _ in range(TOTAL_CNT):
273                    tasks.append(test_client(addr))
274
275                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)
276
277            finally:
278                self.loop.call_soon(srv.close)
279                await srv.wait_closed()
280
281        with self._silence_eof_received_warning():
282            self.loop.run_until_complete(start_server())
283
284        self.assertEqual(CNT, TOTAL_CNT)
285
286        for client in clients:
287            client.stop()
288
289    def test_create_connection_ssl_1(self):
290        self.loop.set_exception_handler(None)
291
292        CNT = 0
293        TOTAL_CNT = 25
294
295        A_DATA = b'A' * 1024 * BUF_MULTIPLIER
296        B_DATA = b'B' * 1024 * BUF_MULTIPLIER
297
298        sslctx = self._create_server_ssl_context(
299            test_utils.ONLYCERT,
300            test_utils.ONLYKEY
301        )
302        client_sslctx = self._create_client_ssl_context()
303
304        def server(sock):
305            sock.starttls(
306                sslctx,
307                server_side=True)
308
309            data = sock.recv_all(len(A_DATA))
310            self.assertEqual(data, A_DATA)
311            sock.send(b'OK')
312
313            data = sock.recv_all(len(B_DATA))
314            self.assertEqual(data, B_DATA)
315            sock.send(b'SPAM')
316
317            sock.close()
318
319        async def client(addr):
320            extras = {}
321            extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
322
323            reader, writer = await asyncio.open_connection(
324                *addr,
325                ssl=client_sslctx,
326                server_hostname='',
327                **extras)
328
329            writer.write(A_DATA)
330            self.assertEqual(await reader.readexactly(2), b'OK')
331
332            writer.write(B_DATA)
333            self.assertEqual(await reader.readexactly(4), b'SPAM')
334
335            nonlocal CNT
336            CNT += 1
337
338            writer.close()
339            await self.wait_closed(writer)
340
341        async def client_sock(addr):
342            sock = socket.socket()
343            sock.connect(addr)
344            reader, writer = await asyncio.open_connection(
345                sock=sock,
346                ssl=client_sslctx,
347                server_hostname='')
348
349            writer.write(A_DATA)
350            self.assertEqual(await reader.readexactly(2), b'OK')
351
352            writer.write(B_DATA)
353            self.assertEqual(await reader.readexactly(4), b'SPAM')
354
355            nonlocal CNT
356            CNT += 1
357
358            writer.close()
359            await self.wait_closed(writer)
360            sock.close()
361
362        def run(coro):
363            nonlocal CNT
364            CNT = 0
365
366            async def _gather(*tasks):
367                # trampoline
368                return await asyncio.gather(*tasks)
369
370            with self.tcp_server(server,
371                                 max_clients=TOTAL_CNT,
372                                 backlog=TOTAL_CNT) as srv:
373                tasks = []
374                for _ in range(TOTAL_CNT):
375                    tasks.append(coro(srv.addr))
376
377                self.loop.run_until_complete(_gather(*tasks))
378
379            self.assertEqual(CNT, TOTAL_CNT)
380
381        with self._silence_eof_received_warning():
382            run(client)
383
384        with self._silence_eof_received_warning():
385            run(client_sock)
386
    def test_create_connection_ssl_slow_handshake(self):
        # The server thread only reads and never answers the TLS handshake,
        # so the client's 1-second ssl_handshake_timeout is expected to abort
        # the connection with ConnectionAbortedError ("SSL handshake ... is
        # taking longer ...").
        client_sslctx = self._create_client_ssl_context()

        # silence error logger
        self.loop.set_exception_handler(lambda *args: None)

        def server(sock):
            # Read whatever the client sends without ever speaking TLS; an
            # abort on this side is tolerated.
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='',
                ssl_handshake_timeout=1.0)
            writer.close()
            await self.wait_closed(writer)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaisesRegex(
                    ConnectionAbortedError,
                    r'SSL handshake.*is taking longer'):

                self.loop.run_until_complete(client(srv.addr))
419
420    def test_create_connection_ssl_failed_certificate(self):
421        # silence error logger
422        self.loop.set_exception_handler(lambda *args: None)
423
424        sslctx = self._create_server_ssl_context(
425            test_utils.ONLYCERT,
426            test_utils.ONLYKEY
427        )
428        client_sslctx = self._create_client_ssl_context(disable_verify=False)
429
430        def server(sock):
431            try:
432                sock.starttls(
433                    sslctx,
434                    server_side=True)
435                sock.connect()
436            except (ssl.SSLError, OSError):
437                pass
438            finally:
439                sock.close()
440
441        async def client(addr):
442            reader, writer = await asyncio.open_connection(
443                *addr,
444                ssl=client_sslctx,
445                server_hostname='',
446                ssl_handshake_timeout=support.SHORT_TIMEOUT)
447            writer.close()
448            await self.wait_closed(writer)
449
450        with self.tcp_server(server,
451                             max_clients=1,
452                             backlog=1) as srv:
453
454            with self.assertRaises(ssl.SSLCertVerificationError):
455                self.loop.run_until_complete(client(srv.addr))
456
    def test_ssl_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        #
        # The server never answers the handshake.  wait_for(..., 0.5) cancels
        # create_connection() long before its 10 s ssl_handshake_timeout.
        # The test expects three things:
        #   - TimeoutError on the client;
        #   - the server's read to end with ConnectionAbortedError;
        #   - nothing passed to the loop's exception handler.
        client_sslctx = test_utils.simple_client_sslcontext()

        # silence error logger
        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        # Set from the server thread when its read is aborted.
        server_side_aborted = False

        def server(sock):
            nonlocal server_side_aborted
            try:
                sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                server_side_aborted = True
            finally:
                sock.close()

        async def client(addr):
            await asyncio.wait_for(
                self.loop.create_connection(
                    asyncio.Protocol,
                    *addr,
                    ssl=client_sslctx,
                    server_hostname='',
                    ssl_handshake_timeout=10.0),
                0.5)

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(client(srv.addr))

        self.assertTrue(server_side_aborted)

        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        self.assertEqual(messages, [])
499
500    def test_ssl_handshake_connection_lost(self):
501        # #246: make sure that no connection_lost() is called before
502        # connection_made() is called first
503
504        client_sslctx = test_utils.simple_client_sslcontext()
505
506        # silence error logger
507        self.loop.set_exception_handler(lambda loop, ctx: None)
508
509        connection_made_called = False
510        connection_lost_called = False
511
512        def server(sock):
513            sock.recv(1024)
514            # break the connection during handshake
515            sock.close()
516
517        class ClientProto(asyncio.Protocol):
518            def connection_made(self, transport):
519                nonlocal connection_made_called
520                connection_made_called = True
521
522            def connection_lost(self, exc):
523                nonlocal connection_lost_called
524                connection_lost_called = True
525
526        async def client(addr):
527            await self.loop.create_connection(
528                ClientProto,
529                *addr,
530                ssl=client_sslctx,
531                server_hostname=''),
532
533        with self.tcp_server(server,
534                             max_clients=1,
535                             backlog=1) as srv:
536
537            with self.assertRaises(ConnectionResetError):
538                self.loop.run_until_complete(client(srv.addr))
539
540        if connection_lost_called:
541            if connection_made_called:
542                self.fail("unexpected call to connection_lost()")
543            else:
544                self.fail("unexpected call to connection_lost() without"
545                          "calling connection_made()")
546        elif connection_made_called:
547            self.fail("unexpected call to connection_made()")
548
549    def test_ssl_connect_accepted_socket(self):
550        proto = ssl.PROTOCOL_TLS_SERVER
551        server_context = ssl.SSLContext(proto)
552        server_context.load_cert_chain(test_utils.ONLYCERT, test_utils.ONLYKEY)
553        if hasattr(server_context, 'check_hostname'):
554            server_context.check_hostname = False
555        server_context.verify_mode = ssl.CERT_NONE
556
557        client_context = ssl.SSLContext(proto)
558        if hasattr(server_context, 'check_hostname'):
559            client_context.check_hostname = False
560        client_context.verify_mode = ssl.CERT_NONE
561
562    def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
563        loop = self.loop
564
565        class MyProto(MyBaseProto):
566
567            def connection_lost(self, exc):
568                super().connection_lost(exc)
569                loop.call_soon(loop.stop)
570
571            def data_received(self, data):
572                super().data_received(data)
573                self.transport.write(expected_response)
574
575        lsock = socket.socket(socket.AF_INET)
576        lsock.bind(('127.0.0.1', 0))
577        lsock.listen(1)
578        addr = lsock.getsockname()
579
580        message = b'test data'
581        response = None
582        expected_response = b'roger'
583
584        def client():
585            nonlocal response
586            try:
587                csock = socket.socket(socket.AF_INET)
588                if client_ssl is not None:
589                    csock = client_ssl.wrap_socket(csock)
590                csock.connect(addr)
591                csock.sendall(message)
592                response = csock.recv(99)
593                csock.close()
594            except Exception as exc:
595                print(
596                    "Failure in client thread in test_connect_accepted_socket",
597                    exc)
598
599        thread = threading.Thread(target=client, daemon=True)
600        thread.start()
601
602        conn, _ = lsock.accept()
603        proto = MyProto(loop=loop)
604        proto.loop = loop
605
606        extras = {}
607        if server_ssl:
608            extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
609
610        f = loop.create_task(
611            loop.connect_accepted_socket(
612                (lambda: proto), conn, ssl=server_ssl,
613                **extras))
614        loop.run_forever()
615        conn.close()
616        lsock.close()
617
618        thread.join(1)
619        self.assertFalse(thread.is_alive())
620        self.assertEqual(proto.state, 'CLOSED')
621        self.assertEqual(proto.nbytes, len(message))
622        self.assertEqual(response, expected_response)
623        tr, _ = f.result()
624
625        if server_ssl:
626            self.assertIn('SSL', tr.__class__.__name__)
627
628        tr.close()
629        # let it close
630        self.loop.run_until_complete(asyncio.sleep(0.1))
631
    def test_start_tls_client_corrupted_ssl(self):
        # After one normal TLS exchange the server writes plaintext bytes on
        # a socket dup taken before starttls(), i.e. underneath the TLS
        # session; the client's next read is expected to fail with
        # ssl.SSLError, after which the client still closes cleanly.
        self.loop.set_exception_handler(lambda loop, ctx: None)

        sslctx = test_utils.simple_server_sslcontext()
        client_sslctx = test_utils.simple_client_sslcontext()

        def server(sock):
            # orig_sock is duplicated before the TLS upgrade so it can later
            # inject unencrypted bytes into the TLS stream.
            orig_sock = sock.dup()
            try:
                sock.starttls(
                    sslctx,
                    server_side=True)
                sock.sendall(b'A\n')
                # Wait for the client's b'B' before corrupting the stream.
                sock.recv_all(1)
                orig_sock.send(b'please corrupt the SSL connection')
            except ssl.SSLError:
                pass
            finally:
                sock.close()
                orig_sock.close()

        async def client(addr):
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')

            self.assertEqual(await reader.readline(), b'A\n')
            writer.write(b'B')
            with self.assertRaises(ssl.SSLError):
                await reader.readline()
            writer.close()
            # Closing a corrupted session may itself report an SSLError.
            try:
                await self.wait_closed(writer)
            except ssl.SSLError:
                pass
            return 'OK'

        with self.tcp_server(server,
                             max_clients=1,
                             backlog=1) as srv:

            res = self.loop.run_until_complete(client(srv.addr))

        self.assertEqual(res, 'OK')
677
    def test_start_tls_client_reg_proto_1(self):
        # Upgrade a plain client connection with loop.start_tls() using a
        # regular asyncio.Protocol.
        #   - The client sends HELLO_MSG in cleartext, upgrades, receives b'O'
        #     and sends HELLO_MSG again over TLS.
        #   - It then waits for eof_received().
        #   - connection_made() must be called exactly once.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            # Cleartext phase.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.starttls(server_context, server_side=True)

            # TLS phase.
            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.unwrap()
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # The instance is named ``proto`` here so that ``self`` refers to
            # the enclosing TestCase (for assertEqual).
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): the purpose of this delay is not shown here;
            # presumably it lets the server thread get ready.
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data, b'O')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))
738
    def test_create_connection_memory_leak(self):
        # A TLS connection made with loop.create_connection(ssl=...) whose
        # protocol stores the transport must not keep the client SSLContext
        # alive once the exchange is over.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_context = self._create_client_ssl_context()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            sock.starttls(server_context, server_side=True)

            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.unwrap()
            sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # ``proto`` is the instance; ``self`` is the enclosing TestCase.
            def connection_made(proto, tr):
                # XXX: We assume user stores the transport in protocol
                proto.tr = tr
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): presumably gives the server thread time to start.
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr,
                ssl=client_context)

            self.assertEqual(await on_data, b'O')
            tr.write(HELLO_MSG)
            await on_eof

            tr.close()

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))

        # No garbage is left for SSL client from loop.create_connection, even
        # if user stores the SSLTransport in corresponding protocol instance
        # NOTE(review): no gc collection runs before this check, so it relies
        # on the context being freed without the cycle collector.
        client_context = weakref.ref(client_context)
        self.assertIsNone(client_context())
802
    def test_start_tls_client_buf_proto_1(self):
        # Sequence under test:
        #   - Upgrade a connection with loop.start_tls() while the protocol is
        #     an asyncio.BufferedProtocol with a 1-byte buffer.
        #   - Swap in a regular Protocol with transport.set_protocol().
        # connection_made() must be called only once in total: start_tls()
        # and set_protocol() do not call it again.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        # Shared by both protocol classes.
        client_con_made_calls = 0

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            # Cleartext phase.
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.starttls(server_context, server_side=True)

            # Served to ClientProtoFirst.
            sock.sendall(b'O')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            # Served to ClientProtoSecond.
            sock.sendall(b'2')
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.unwrap()
            sock.close()

        class ClientProtoFirst(asyncio.BufferedProtocol):
            def __init__(self, on_data):
                self.on_data = on_data
                # One-byte buffer: exactly b'O' is expected.
                self.buf = bytearray(1)

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def get_buffer(self, sizehint):
                return self.buf

            def buffer_updated(self, nsize):
                assert nsize == 1
                self.on_data.set_result(bytes(self.buf[:nsize]))

            def eof_received(self):
                pass

        class ClientProtoSecond(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            def connection_made(self, tr):
                nonlocal client_con_made_calls
                client_con_made_calls += 1

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): presumably gives the server thread time to start.
            await asyncio.sleep(0.5)

            on_data1 = self.loop.create_future()
            on_data2 = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProtoFirst(on_data1), *addr)

            tr.write(HELLO_MSG)
            new_tr = await self.loop.start_tls(tr, proto, client_context)

            self.assertEqual(await on_data1, b'O')
            new_tr.write(HELLO_MSG)

            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
            self.assertEqual(await on_data2, b'2')
            new_tr.write(HELLO_MSG)
            await on_eof

            new_tr.close()

            # connection_made() should be called only once -- when
            # we establish connection for the first time. Start TLS
            # doesn't call connection_made() on application protocols.
            self.assertEqual(client_con_made_calls, 1)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=self.TIMEOUT))
897
    def test_start_tls_slow_client_cancel(self):
        # The server receives the cleartext HELLO_MSG but never performs a
        # TLS handshake.  loop.start_tls() wrapped in wait_for(..., 0.5) is
        # therefore expected to end in TimeoutError on the client.
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        client_context = test_utils.simple_client_sslcontext()
        # Resolved from the server thread just before its final read.  The
        # client waits for it before calling start_tls().
        server_waits_on_handshake = self.loop.create_future()

        def serve(sock):
            sock.settimeout(self.TIMEOUT)

            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            try:
                self.loop.call_soon_threadsafe(
                    server_waits_on_handshake.set_result, None)
                # Read without ever answering the handshake.
                data = sock.recv_all(1024 * 1024)
            except ConnectionAbortedError:
                pass
            finally:
                sock.close()

        class ClientProto(asyncio.Protocol):
            def __init__(self, on_data, on_eof):
                self.on_data = on_data
                self.on_eof = on_eof
                self.con_made_cnt = 0

            # ``proto`` is the instance; ``self`` is the enclosing TestCase.
            def connection_made(proto, tr):
                proto.con_made_cnt += 1
                # Ensure connection_made gets called only once.
                self.assertEqual(proto.con_made_cnt, 1)

            def data_received(self, data):
                self.on_data.set_result(data)

            def eof_received(self):
                self.on_eof.set_result(True)

        async def client(addr):
            # NOTE(review): presumably gives the server thread time to start.
            await asyncio.sleep(0.5)

            on_data = self.loop.create_future()
            on_eof = self.loop.create_future()

            tr, proto = await self.loop.create_connection(
                lambda: ClientProto(on_data, on_eof), *addr)

            tr.write(HELLO_MSG)

            await server_waits_on_handshake

            with self.assertRaises(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self.loop.start_tls(tr, proto, client_context),
                    0.5)

        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
            self.loop.run_until_complete(
                asyncio.wait_for(client(srv.addr),
                                 timeout=support.SHORT_TIMEOUT))
958
    def test_start_tls_server_1(self):
        """Server-side start_tls() upgrades an accepted plaintext connection.

        The server writes HELLO_MSG in plaintext and then upgrades the
        transport with ``server_side=True``. The threaded client reads the
        plaintext message, performs its side of the TLS handshake and sends
        HELLO_MSG back encrypted. The server protocol must receive exactly
        that payload, followed by EOF and connection_lost().
        """
        HELLO_MSG = b'1' * self.PAYLOAD_SIZE

        server_context = test_utils.simple_server_sslcontext()
        client_context = test_utils.simple_client_sslcontext()

        def client(sock, addr):
            sock.settimeout(self.TIMEOUT)

            sock.connect(addr)
            data = sock.recv_all(len(HELLO_MSG))
            self.assertEqual(len(data), len(HELLO_MSG))

            sock.starttls(client_context)
            sock.sendall(HELLO_MSG)

            sock.unwrap()
            sock.close()

        class ServerProto(asyncio.Protocol):
            def __init__(self, on_con, on_eof, on_con_lost):
                self.on_con = on_con
                self.on_eof = on_eof
                self.on_con_lost = on_con_lost
                self.data = b''

            def connection_made(self, tr):
                self.on_con.set_result(tr)

            def data_received(self, data):
                self.data += data

            def eof_received(self):
                self.on_eof.set_result(1)

            def connection_lost(self, exc):
                if exc is None:
                    self.on_con_lost.set_result(None)
                else:
                    self.on_con_lost.set_exception(exc)

        async def main(proto, on_con, on_eof, on_con_lost):
            tr = await on_con
            tr.write(HELLO_MSG)

            # The client only sends after its TLS upgrade, so nothing can
            # have arrived yet.
            self.assertEqual(proto.data, b'')

            new_tr = await self.loop.start_tls(
                tr, proto, server_context,
                server_side=True,
                ssl_handshake_timeout=self.TIMEOUT)

            await on_eof
            await on_con_lost
            self.assertEqual(proto.data, HELLO_MSG)
            new_tr.close()

        async def run_main():
            on_con = self.loop.create_future()
            on_eof = self.loop.create_future()
            on_con_lost = self.loop.create_future()
            # One protocol instance: the server factory returns it, and it
            # is passed to start_tls() so it keeps receiving data after the
            # upgrade.
            proto = ServerProto(on_con, on_eof, on_con_lost)

            server = await self.loop.create_server(
                lambda: proto, '127.0.0.1', 0)
            addr = server.sockets[0].getsockname()

            with self.tcp_client(lambda sock: client(sock, addr),
                                 timeout=self.TIMEOUT):
                await asyncio.wait_for(
                    main(proto, on_con, on_eof, on_con_lost),
                    timeout=self.TIMEOUT)

            server.close()
            await server.wait_closed()

        self.loop.run_until_complete(run_main())
1036
    def test_create_server_ssl_over_ssl(self):
        """TLS can be layered on top of an existing TLS transport.

        The server listens with ``sslctx_1``. Each accepted (already TLS)
        transport is upgraded a second time with ``loop.start_tls()`` using
        ``sslctx_2`` before it is handed to a stream handler. TOTAL_CNT
        threaded clients speak the outer TLS via the socket wrapper's
        starttls() and drive the inner TLS by hand through memory BIOs.
        """
        CNT = 0           # number of clients that were successful
        TOTAL_CNT = 25    # total number of clients that test will create
        TIMEOUT = support.LONG_TIMEOUT  # timeout for this test

        A_DATA = b'A' * 1024 * BUF_MULTIPLIER
        B_DATA = b'B' * 1024 * BUF_MULTIPLIER

        sslctx_1 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_1 = self._create_client_ssl_context()
        sslctx_2 = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx_2 = self._create_client_ssl_context()

        # Started client threads; stopped after the loop has finished.
        clients = []

        async def handle_client(reader, writer):
            nonlocal CNT

            data = await reader.readexactly(len(A_DATA))
            self.assertEqual(data, A_DATA)
            writer.write(b'OK')

            data = await reader.readexactly(len(B_DATA))
            self.assertEqual(data, B_DATA)
            # Deliberately mixes bytes, bytearray and memoryview buffers.
            writer.writelines([b'SP', bytearray(b'A'), memoryview(b'M')])

            await writer.drain()
            writer.close()

            CNT += 1

        class ServerProtocol(asyncio.StreamReaderProtocol):
            def connection_made(self, transport):
                # Bind super() here; zero-argument super() would not work
                # inside the nested cb() below.
                super_ = super()
                # Stop reading on the outer transport until the inner TLS
                # layer has been set up.
                transport.pause_reading()
                fut = self._loop.create_task(self._loop.start_tls(
                    transport, self, sslctx_2, server_side=True))

                def cb(_):
                    try:
                        tr = fut.result()
                    except Exception as ex:
                        super_.connection_lost(ex)
                    else:
                        # StreamReaderProtocol only sees the connection now,
                        # and sees the inner TLS transport returned by
                        # start_tls().
                        super_.connection_made(tr)
                fut.add_done_callback(cb)

        def server_protocol_factory():
            reader = asyncio.StreamReader()
            protocol = ServerProtocol(reader, handle_client)
            return protocol

        async def test_client(addr):
            # Resolved from the client thread with its outcome.
            fut = asyncio.Future()

            def prog(sock):
                try:
                    sock.connect(addr)
                    sock.starttls(client_sslctx_1)

                    # because wrap_socket() doesn't work correctly on
                    # SSLSocket, we have to do the 2nd level SSL manually
                    incoming = ssl.MemoryBIO()
                    outgoing = ssl.MemoryBIO()
                    sslobj = client_sslctx_2.wrap_bio(incoming, outgoing)

                    # Retry func(*args) while the inner TLS object wants
                    # more input: flush pending inner output to the socket
                    # and feed socket data into the incoming BIO. Any
                    # output left afterwards is flushed before returning.
                    def do(func, *args):
                        while True:
                            try:
                                rv = func(*args)
                                break
                            except ssl.SSLWantReadError:
                                if outgoing.pending:
                                    sock.send(outgoing.read())
                                incoming.write(sock.recv(65536))
                        if outgoing.pending:
                            sock.send(outgoing.read())
                        return rv

                    do(sslobj.do_handshake)

                    do(sslobj.write, A_DATA)
                    data = do(sslobj.read, 2)
                    self.assertEqual(data, b'OK')

                    do(sslobj.write, B_DATA)
                    # Read small chunks until the inner TLS read returns
                    # an empty chunk.
                    data = b''
                    while True:
                        chunk = do(sslobj.read, 4)
                        if not chunk:
                            break
                        data += chunk
                    self.assertEqual(data, b'SPAM')

                    do(sslobj.unwrap)
                    sock.close()

                except Exception as ex:
                    self.loop.call_soon_threadsafe(fut.set_exception, ex)
                    sock.close()
                else:
                    self.loop.call_soon_threadsafe(fut.set_result, None)

            client = self.tcp_client(prog)
            client.start()
            clients.append(client)

            await fut

        async def start_server():
            extras = {}

            srv = await self.loop.create_server(
                server_protocol_factory,
                '127.0.0.1', 0,
                family=socket.AF_INET,
                ssl=sslctx_1,
                **extras)

            try:
                srv_socks = srv.sockets
                self.assertTrue(srv_socks)

                addr = srv_socks[0].getsockname()

                tasks = []
                for _ in range(TOTAL_CNT):
                    tasks.append(test_client(addr))

                await asyncio.wait_for(asyncio.gather(*tasks), TIMEOUT)

            finally:
                self.loop.call_soon(srv.close)
                await srv.wait_closed()

        with self._silence_eof_received_warning():
            self.loop.run_until_complete(start_server())

        # Every server-side handler must have completed its exchange.
        self.assertEqual(CNT, TOTAL_CNT)

        for client in clients:
            client.stop()
1181
1182    def test_shutdown_cleanly(self):
1183        CNT = 0
1184        TOTAL_CNT = 25
1185
1186        A_DATA = b'A' * 1024 * BUF_MULTIPLIER
1187
1188        sslctx = self._create_server_ssl_context(
1189            test_utils.ONLYCERT, test_utils.ONLYKEY)
1190        client_sslctx = self._create_client_ssl_context()
1191
1192        def server(sock):
1193            sock.starttls(
1194                sslctx,
1195                server_side=True)
1196
1197            data = sock.recv_all(len(A_DATA))
1198            self.assertEqual(data, A_DATA)
1199            sock.send(b'OK')
1200
1201            sock.unwrap()
1202
1203            sock.close()
1204
1205        async def client(addr):
1206            extras = {}
1207            extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
1208
1209            reader, writer = await asyncio.open_connection(
1210                *addr,
1211                ssl=client_sslctx,
1212                server_hostname='',
1213                **extras)
1214
1215            writer.write(A_DATA)
1216            self.assertEqual(await reader.readexactly(2), b'OK')
1217
1218            self.assertEqual(await reader.read(), b'')
1219
1220            nonlocal CNT
1221            CNT += 1
1222
1223            writer.close()
1224            await self.wait_closed(writer)
1225
1226        def run(coro):
1227            nonlocal CNT
1228            CNT = 0
1229
1230            async def _gather(*tasks):
1231                return await asyncio.gather(*tasks)
1232
1233            with self.tcp_server(server,
1234                                 max_clients=TOTAL_CNT,
1235                                 backlog=TOTAL_CNT) as srv:
1236                tasks = []
1237                for _ in range(TOTAL_CNT):
1238                    tasks.append(coro(srv.addr))
1239
1240                self.loop.run_until_complete(
1241                    _gather(*tasks))
1242
1243            self.assertEqual(CNT, TOTAL_CNT)
1244
1245        with self._silence_eof_received_warning():
1246            run(client)
1247
    def test_flush_before_shutdown(self):
        """Data queued before close() is delivered before the TLS shutdown.

        The client pauses writing on its SSLProtocol, queues CHUNK * SIZE
        bytes, calls close() and only then resumes writing. The server must
        still receive every queued byte.
        """
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT, test_utils.ONLYKEY)
        client_sslctx = self._create_client_ssl_context()

        # Per-run future carrying the server thread's outcome; created by
        # client() before it connects.
        future = None

        def server(sock):
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')
            time.sleep(0.5)  # give the client time to fill the TCP buffer
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)
            sock.close()

        # Wrap a server program so its result or exception is forwarded to
        # ``future`` on the event loop thread.
        def run(meth):
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()
            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            # Private attribute: the SSLProtocol behind the stream transport.
            sslprotocol = writer.transport._ssl_protocol
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # Presumably makes the following writes pile up in the SSL
            # layer's buffer instead of being flushed immediately.
            sslprotocol.pause_writing()
            for _ in range(SIZE):
                writer.write(b'x' * CHUNK)

            # close() is requested while data is still pending.
            writer.close()
            sslprotocol.resume_writing()

            await self.wait_closed(writer)
            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except ConnectionResetError:
                pass
            await future

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
1306
    def test_remote_shutdown_receives_trailing_data(self):
        """Queued client data still reaches a peer that started shutdown.

        The client fills its write backlog with CHUNK * SIZE bytes right
        after the ping/pong exchange. The test runs against two servers:

        * ``server`` sends a TLS close_notify without waiting for the reply.
          It then keeps reading, and must get all queued data followed by
          the client's close_notify.
        * ``eof_server`` half-closes the TCP connection (SHUT_WR) instead,
          and must also receive all queued data.
        """
        CHUNK = 1024 * 128
        SIZE = 32

        sslctx = self._create_server_ssl_context(
            test_utils.ONLYCERT,
            test_utils.ONLYKEY
        )
        client_sslctx = self._create_client_ssl_context()
        # Per-run future carrying the server thread's outcome; created by
        # client() before it connects.
        future = None

        def server(sock):
            # TLS is driven manually through memory BIOs so that close_notify
            # can be sent without blocking for the peer's response.
            incoming = ssl.MemoryBIO()
            outgoing = ssl.MemoryBIO()
            sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

            while True:
                try:
                    sslobj.do_handshake()
                except ssl.SSLWantReadError:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    incoming.write(sock.recv(16384))
                else:
                    if outgoing.pending:
                        sock.send(outgoing.read())
                    break

            while True:
                try:
                    data = sslobj.read(4)
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                else:
                    break

            self.assertEqual(data, b'ping')
            sslobj.write(b'pong')
            sock.send(outgoing.read())

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send close_notify but don't wait for response
            with self.assertRaises(ssl.SSLWantReadError):
                sslobj.unwrap()
            sock.send(outgoing.read())

            # should receive all data
            data_len = 0
            while True:
                try:
                    chunk = len(sslobj.read(16384))
                    data_len += chunk
                except ssl.SSLWantReadError:
                    incoming.write(sock.recv(16384))
                except ssl.SSLZeroReturnError:
                    break

            self.assertEqual(data_len, CHUNK * SIZE)

            # verify that close_notify is received
            sslobj.unwrap()

            sock.close()

        def eof_server(sock):
            sock.starttls(sslctx, server_side=True)
            self.assertEqual(sock.recv_all(4), b'ping')
            sock.send(b'pong')

            time.sleep(0.2)  # wait for the peer to fill its backlog

            # send EOF
            sock.shutdown(socket.SHUT_WR)

            # should receive all data
            data = sock.recv_all(CHUNK * SIZE)
            self.assertEqual(len(data), CHUNK * SIZE)

            sock.close()

        async def client(addr):
            nonlocal future
            future = self.loop.create_future()

            reader, writer = await asyncio.open_connection(
                *addr,
                ssl=client_sslctx,
                server_hostname='')
            writer.write(b'ping')
            data = await reader.readexactly(4)
            self.assertEqual(data, b'pong')

            # fill write backlog in a hacky way - renegotiation won't help
            # (_test__append_write_backlog is a private test-only hook on
            # the SSL transport)
            for _ in range(SIZE):
                writer.transport._test__append_write_backlog(b'x' * CHUNK)

            try:
                data = await reader.read()
                self.assertEqual(data, b'')
            except (BrokenPipeError, ConnectionResetError):
                pass

            await future

            writer.close()
            await self.wait_closed(writer)

        # Wrap a server program so its result or exception is forwarded to
        # ``future`` on the event loop thread.
        def run(meth):
            def wrapper(sock):
                try:
                    meth(sock)
                except Exception as ex:
                    self.loop.call_soon_threadsafe(future.set_exception, ex)
                else:
                    self.loop.call_soon_threadsafe(future.set_result, None)
            return wrapper

        with self.tcp_server(run(server)) as srv:
            self.loop.run_until_complete(client(srv.addr))

        with self.tcp_server(run(eof_server)) as srv:
            self.loop.run_until_complete(client(srv.addr))
1430
    def test_connect_timeout_warning(self):
        """A failed or timed-out SSL connect must not emit a ResourceWarning.

        The target socket is bound but never listen()s. The connection
        attempt is therefore expected to be refused or to time out after
        0.1 seconds.
        """
        s = socket.socket(socket.AF_INET)
        s.bind(('127.0.0.1', 0))
        addr = s.getsockname()

        async def test():
            try:
                await asyncio.wait_for(
                    self.loop.create_connection(asyncio.Protocol,
                                                *addr, ssl=True),
                    0.1)
            except (ConnectionRefusedError, asyncio.TimeoutError):
                pass
            else:
                self.fail('TimeoutError is not raised')

        with s:
            # Inverted check: assertWarns() failing with "not triggered" is
            # the passing outcome; any ResourceWarning fails the test.
            try:
                with self.assertWarns(ResourceWarning) as cm:
                    self.loop.run_until_complete(test())
                    # Collect repeatedly so that objects in reference cycles
                    # are finalized (and would warn) inside the block.
                    gc.collect()
                    gc.collect()
                    gc.collect()
            except AssertionError as e:
                self.assertEqual(str(e), 'ResourceWarning not triggered')
            else:
                self.fail('Unexpected ResourceWarning: {}'.format(cm.warning))
1458
    def test_handshake_timeout_handler_leak(self):
        """An abandoned TLS handshake must not keep the SSL context alive.

        The server socket listens but never accepts or responds, so the
        handshake cannot finish and wait_for() gives up after 0.1 seconds.
        Afterwards the client SSLContext must already be freed. No
        gc.collect() is called, so this relies on reference counting alone.
        """
        s = socket.socket(socket.AF_INET)
        s.bind(('127.0.0.1', 0))
        s.listen(1)
        addr = s.getsockname()

        async def test(ctx):
            try:
                await asyncio.wait_for(
                    self.loop.create_connection(asyncio.Protocol, *addr,
                                                ssl=ctx),
                    0.1)
            except (ConnectionRefusedError, asyncio.TimeoutError):
                pass
            else:
                self.fail('TimeoutError is not raised')

        with s:
            ctx = ssl.create_default_context()
            self.loop.run_until_complete(test(ctx))
            # Drop the strong reference; only the weakref remains.
            ctx = weakref.ref(ctx)

        # SSLProtocol should be DECREF to 0
        self.assertIsNone(ctx())
1483
    def test_shutdown_timeout_handler_leak(self):
        """Closing an SSL transport must not leak its SSL context.

        The client connects, immediately closes the transport and waits for
        connection_lost(). After garbage collection, the client SSLContext
        must be freed.
        """
        loop = self.loop

        def server(sock):
            sslctx = self._create_server_ssl_context(
                test_utils.ONLYCERT,
                test_utils.ONLYKEY
            )
            sock = sslctx.wrap_socket(sock, server_side=True)
            sock.recv(32)
            sock.close()

        class Protocol(asyncio.Protocol):
            def __init__(self):
                self.fut = asyncio.Future(loop=loop)

            def connection_lost(self, exc):
                self.fut.set_result(None)

        async def client(addr, ctx):
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            tr.close()
            await pr.fut

        with self.tcp_server(server) as srv:
            ctx = self._create_client_ssl_context()
            loop.run_until_complete(client(srv.addr, ctx))
            # Drop the strong reference; only the weakref remains.
            ctx = weakref.ref(ctx)

        # asyncio has no shutdown timeout, but it ends up with a circular
        # reference loop - not ideal (introduces gc glitches), but at least
        # not leaking
        gc.collect()
        gc.collect()
        gc.collect()

        # SSLProtocol should be DECREF to 0
        self.assertIsNone(ctx())
1522
    def test_shutdown_timeout_handler_not_set(self):
        """Data buffered while reading is paused survives a peer EOF.

        The server sends b'hello', waits for b'world', then sends
        b'extra bytes' and half-closes its side. The client paused reading
        after 'hello', so the extra bytes arrive while reading is paused.
        After the EOF is signalled, the client resumes reading. The buffered
        b'extra bytes' must still reach data_received() before the
        connection is reported lost.
        """
        loop = self.loop
        # Set from the server thread right after it has sent its EOF.
        eof = asyncio.Event()
        # Whatever the client protocol receives after b'hello'.
        extra = None

        def server(sock):
            sslctx = self._create_server_ssl_context(
                test_utils.ONLYCERT,
                test_utils.ONLYKEY
            )
            sock = sslctx.wrap_socket(sock, server_side=True)
            sock.send(b'hello')
            assert sock.recv(1024) == b'world'
            sock.send(b'extra bytes')
            # sending EOF here
            sock.shutdown(socket.SHUT_WR)
            loop.call_soon_threadsafe(eof.set)
            # make sure we have enough time to reproduce the issue
            assert sock.recv(1024) == b''
            sock.close()

        class Protocol(asyncio.Protocol):
            def __init__(self):
                self.fut = asyncio.Future(loop=loop)
                self.transport = None

            def connection_made(self, transport):
                self.transport = transport

            def data_received(self, data):
                if data == b'hello':
                    self.transport.write(b'world')
                    # pause reading would make incoming data stay in the sslobj
                    self.transport.pause_reading()
                else:
                    nonlocal extra
                    extra = data

            def connection_lost(self, exc):
                if exc is None:
                    self.fut.set_result(None)
                else:
                    self.fut.set_exception(exc)

        async def client(addr):
            ctx = self._create_client_ssl_context()
            tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
            await eof.wait()
            tr.resume_reading()
            await pr.fut
            tr.close()
            assert extra == b'extra bytes'

        with self.tcp_server(server) as srv:
            loop.run_until_complete(client(srv.addr))
1578
1579
1580###############################################################################
1581# Socket Testing Utilities
1582###############################################################################
1583
1584
class TestSocketWrapper:
    """Thin wrapper around a socket used by the threaded test helpers.

    Attributes not defined here are forwarded to the wrapped socket, so the
    wrapper can stand in for a socket. starttls() replaces the wrapped
    socket with an ssl.SSLSocket in place.
    """

    def __init__(self, sock):
        self.__sock = sock

    def recv_all(self, n):
        """Receive exactly *n* bytes from the socket and return them as bytes.

        Raises ConnectionAbortedError if recv() returns b'' (the peer closed
        the connection) before *n* bytes have arrived.
        """
        # Accumulate into a bytearray: extending it is amortized O(1),
        # whereas ``bytes += bytes`` copies the whole buffer for every
        # received chunk, which is quadratic for large payloads.
        buf = bytearray()
        while len(buf) < n:
            data = self.recv(n - len(buf))
            if not data:
                raise ConnectionAbortedError
            buf += data
        return bytes(buf)

    def starttls(self, ssl_context, *,
                 server_side=False,
                 server_hostname=None,
                 do_handshake_on_connect=True):
        """Wrap the current socket with *ssl_context* and use it from now on.

        For ``server_side=True`` do_handshake() is additionally called
        explicitly. The original socket object is closed afterwards.
        """

        assert isinstance(ssl_context, ssl.SSLContext)

        ssl_sock = ssl_context.wrap_socket(
            self.__sock, server_side=server_side,
            server_hostname=server_hostname,
            do_handshake_on_connect=do_handshake_on_connect)

        if server_side:
            ssl_sock.do_handshake()

        self.__sock.close()
        self.__sock = ssl_sock

    def __getattr__(self, name):
        # Only called for attributes not found on the wrapper itself.
        return getattr(self.__sock, name)

    def __repr__(self):
        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
1622
1623
class SocketThread(threading.Thread):
    """Thread base class that can be used as a context manager.

    Entering a ``with`` block starts the thread and yields it; leaving the
    block calls stop(). stop() clears ``self._active``, which the
    subclasses in this module check to notice that they should finish.
    """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc_info):
        self.stop()

    def stop(self):
        """Request shutdown via ``_active`` and block until the thread ends."""
        self._active = False
        self.join()
1636
1637
class TestThreadedClient(SocketThread):
    """Daemon thread that runs a client program once against a socket.

    ``prog`` is called with ``sock`` wrapped in TestSocketWrapper. Any
    exception it raises, other than KeyboardInterrupt or SystemExit, is
    handed to ``test._abort_socket_test``. Those two are propagated.
    """

    def __init__(self, test, sock, prog, timeout):
        super().__init__(name='test-client')
        self.daemon = True

        self._test = test
        self._sock = sock
        self._prog = prog
        self._timeout = timeout
        self._active = True

    def run(self):
        try:
            self._prog(TestSocketWrapper(self._sock))
        except BaseException as ex:
            if isinstance(ex, (KeyboardInterrupt, SystemExit)):
                raise
            self._test._abort_socket_test(ex)
1657
1658
class TestThreadedServer(SocketThread):
    """Daemon thread that accepts connections and runs ``prog`` on each.

    Clients are handled one after another on this thread. Each accepted
    connection gets ``timeout`` as its socket timeout and is passed to
    ``prog`` wrapped in TestSocketWrapper. The accept loop ends when any of
    these happens:

    * ``max_clients`` connections have been served;
    * stop() is called;
    * a client's ``prog`` raises. The error is reported via
      ``test._abort_socket_test`` and re-raised on this thread.
    """

    def __init__(self, test, sock, prog, timeout, max_clients):
        threading.Thread.__init__(self, None, None, 'test-server')
        self.daemon = True

        self._clients = 0
        # NOTE(review): not referenced anywhere in this class; kept in case
        # other code reads it.
        self._finished_clients = 0
        self._max_clients = max_clients
        self._timeout = timeout
        self._sock = sock
        self._active = True

        self._prog = prog

        # Self-pipe: stop() writes to _s2 so that select() in _run() wakes
        # up on _s1.
        self._s1, self._s2 = socket.socketpair()
        self._s1.setblocking(False)

        self._test = test

    def stop(self):
        """Wake the accept loop via the self-pipe, then stop and join."""
        try:
            if self._s2 and self._s2.fileno() != -1:
                try:
                    self._s2.send(b'stop')
                except OSError:
                    # The thread may already have closed the pair.
                    pass
        finally:
            super().stop()

    def run(self):
        # The listening socket and the self-pipe pair are closed when the
        # accept loop exits, however it exits.
        try:
            with self._sock:
                self._sock.setblocking(False)
                self._run()
        finally:
            self._s1.close()
            self._s2.close()

    def _run(self):
        """Accept and serve clients until one of the stop conditions holds."""
        while self._active:
            if self._clients >= self._max_clients:
                return

            # If nothing becomes ready within the timeout, ``r`` is empty
            # and the loop simply polls again.
            r, w, x = select.select(
                [self._sock, self._s1], [], [], self._timeout)

            if self._s1 in r:
                return

            if self._sock in r:
                try:
                    conn, addr = self._sock.accept()
                except BlockingIOError:
                    # The listening socket is non-blocking; nothing to
                    # accept after all.
                    continue
                except socket.timeout:
                    # NOTE(review): with a non-blocking listening socket
                    # this branch appears unreachable.
                    if not self._active:
                        return
                    else:
                        raise
                else:
                    self._clients += 1
                    conn.settimeout(self._timeout)
                    try:
                        with conn:
                            self._handle_client(conn)
                    except (KeyboardInterrupt, SystemExit):
                        raise
                    except BaseException as ex:
                        self._active = False
                        # Re-raise the client error, reporting it to the
                        # test first (the finally runs before propagation).
                        try:
                            raise
                        finally:
                            self._test._abort_socket_test(ex)

    def _handle_client(self, sock):
        self._prog(TestSocketWrapper(sock))

    @property
    def addr(self):
        """Local address the listening socket is bound to."""
        return self._sock.getsockname()
1740