1"""Tests for streams.py."""
2
3import gc
4import os
5import queue
6import pickle
7import socket
8import sys
9import threading
10import unittest
11from unittest import mock
12from test.support import socket_helper
13try:
14    import ssl
15except ImportError:
16    ssl = None
17
18import asyncio
19from test.test_asyncio import utils as test_utils
20
21
def tearDownModule():
    """Undo any event loop policy changes once the module's tests finish."""
    # None reinstates asyncio's default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
24
25
26class StreamTests(test_utils.TestCase):
27
    # Three newline-terminated lines (18 bytes in total) fed to readers by
    # many of the tests below.
    DATA = b'line1\nline2\nline3\n'
29
30    def setUp(self):
31        super().setUp()
32        self.loop = asyncio.new_event_loop()
33        self.set_event_loop(self.loop)
34
35    def tearDown(self):
36        # just in case if we have transport close callbacks
37        test_utils.run_briefly(self.loop)
38
39        self.loop.close()
40        gc.collect()
41        super().tearDown()
42
43    def _basetest_open_connection(self, open_connection_fut):
44        messages = []
45        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
46        reader, writer = self.loop.run_until_complete(open_connection_fut)
47        writer.write(b'GET / HTTP/1.0\r\n\r\n')
48        f = reader.readline()
49        data = self.loop.run_until_complete(f)
50        self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
51        f = reader.read()
52        data = self.loop.run_until_complete(f)
53        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
54        writer.close()
55        self.assertEqual(messages, [])
56
57    def test_open_connection(self):
58        with test_utils.run_test_server() as httpd:
59            conn_fut = asyncio.open_connection(*httpd.address)
60            self._basetest_open_connection(conn_fut)
61
62    @socket_helper.skip_unless_bind_unix_socket
63    def test_open_unix_connection(self):
64        with test_utils.run_test_unix_server() as httpd:
65            conn_fut = asyncio.open_unix_connection(httpd.address)
66            self._basetest_open_connection(conn_fut)
67
68    def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
69        messages = []
70        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
71        try:
72            reader, writer = self.loop.run_until_complete(open_connection_fut)
73        finally:
74            asyncio.set_event_loop(None)
75        writer.write(b'GET / HTTP/1.0\r\n\r\n')
76        f = reader.read()
77        data = self.loop.run_until_complete(f)
78        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
79
80        writer.close()
81        self.assertEqual(messages, [])
82
83    @unittest.skipIf(ssl is None, 'No ssl module')
84    def test_open_connection_no_loop_ssl(self):
85        with test_utils.run_test_server(use_ssl=True) as httpd:
86            conn_fut = asyncio.open_connection(
87                *httpd.address,
88                ssl=test_utils.dummy_ssl_context())
89
90            self._basetest_open_connection_no_loop_ssl(conn_fut)
91
92    @socket_helper.skip_unless_bind_unix_socket
93    @unittest.skipIf(ssl is None, 'No ssl module')
94    def test_open_unix_connection_no_loop_ssl(self):
95        with test_utils.run_test_unix_server(use_ssl=True) as httpd:
96            conn_fut = asyncio.open_unix_connection(
97                httpd.address,
98                ssl=test_utils.dummy_ssl_context(),
99                server_hostname='',
100            )
101
102            self._basetest_open_connection_no_loop_ssl(conn_fut)
103
104    def _basetest_open_connection_error(self, open_connection_fut):
105        messages = []
106        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
107        reader, writer = self.loop.run_until_complete(open_connection_fut)
108        writer._protocol.connection_lost(ZeroDivisionError())
109        f = reader.read()
110        with self.assertRaises(ZeroDivisionError):
111            self.loop.run_until_complete(f)
112        writer.close()
113        test_utils.run_briefly(self.loop)
114        self.assertEqual(messages, [])
115
116    def test_open_connection_error(self):
117        with test_utils.run_test_server() as httpd:
118            conn_fut = asyncio.open_connection(*httpd.address)
119            self._basetest_open_connection_error(conn_fut)
120
121    @socket_helper.skip_unless_bind_unix_socket
122    def test_open_unix_connection_error(self):
123        with test_utils.run_test_unix_server() as httpd:
124            conn_fut = asyncio.open_unix_connection(httpd.address)
125            self._basetest_open_connection_error(conn_fut)
126
127    def test_feed_empty_data(self):
128        stream = asyncio.StreamReader(loop=self.loop)
129
130        stream.feed_data(b'')
131        self.assertEqual(b'', stream._buffer)
132
133    def test_feed_nonempty_data(self):
134        stream = asyncio.StreamReader(loop=self.loop)
135
136        stream.feed_data(self.DATA)
137        self.assertEqual(self.DATA, stream._buffer)
138
139    def test_read_zero(self):
140        # Read zero bytes.
141        stream = asyncio.StreamReader(loop=self.loop)
142        stream.feed_data(self.DATA)
143
144        data = self.loop.run_until_complete(stream.read(0))
145        self.assertEqual(b'', data)
146        self.assertEqual(self.DATA, stream._buffer)
147
148    def test_read(self):
149        # Read bytes.
150        stream = asyncio.StreamReader(loop=self.loop)
151        read_task = self.loop.create_task(stream.read(30))
152
153        def cb():
154            stream.feed_data(self.DATA)
155        self.loop.call_soon(cb)
156
157        data = self.loop.run_until_complete(read_task)
158        self.assertEqual(self.DATA, data)
159        self.assertEqual(b'', stream._buffer)
160
161    def test_read_line_breaks(self):
162        # Read bytes without line breaks.
163        stream = asyncio.StreamReader(loop=self.loop)
164        stream.feed_data(b'line1')
165        stream.feed_data(b'line2')
166
167        data = self.loop.run_until_complete(stream.read(5))
168
169        self.assertEqual(b'line1', data)
170        self.assertEqual(b'line2', stream._buffer)
171
172    def test_read_eof(self):
173        # Read bytes, stop at eof.
174        stream = asyncio.StreamReader(loop=self.loop)
175        read_task = self.loop.create_task(stream.read(1024))
176
177        def cb():
178            stream.feed_eof()
179        self.loop.call_soon(cb)
180
181        data = self.loop.run_until_complete(read_task)
182        self.assertEqual(b'', data)
183        self.assertEqual(b'', stream._buffer)
184
185    def test_read_until_eof(self):
186        # Read all bytes until eof.
187        stream = asyncio.StreamReader(loop=self.loop)
188        read_task = self.loop.create_task(stream.read(-1))
189
190        def cb():
191            stream.feed_data(b'chunk1\n')
192            stream.feed_data(b'chunk2')
193            stream.feed_eof()
194        self.loop.call_soon(cb)
195
196        data = self.loop.run_until_complete(read_task)
197
198        self.assertEqual(b'chunk1\nchunk2', data)
199        self.assertEqual(b'', stream._buffer)
200
201    def test_read_exception(self):
202        stream = asyncio.StreamReader(loop=self.loop)
203        stream.feed_data(b'line\n')
204
205        data = self.loop.run_until_complete(stream.read(2))
206        self.assertEqual(b'li', data)
207
208        stream.set_exception(ValueError())
209        self.assertRaises(
210            ValueError, self.loop.run_until_complete, stream.read(2))
211
212    def test_invalid_limit(self):
213        with self.assertRaisesRegex(ValueError, 'imit'):
214            asyncio.StreamReader(limit=0, loop=self.loop)
215
216        with self.assertRaisesRegex(ValueError, 'imit'):
217            asyncio.StreamReader(limit=-1, loop=self.loop)
218
219    def test_read_limit(self):
220        stream = asyncio.StreamReader(limit=3, loop=self.loop)
221        stream.feed_data(b'chunk')
222        data = self.loop.run_until_complete(stream.read(5))
223        self.assertEqual(b'chunk', data)
224        self.assertEqual(b'', stream._buffer)
225
226    def test_readline(self):
227        # Read one line. 'readline' will need to wait for the data
228        # to come from 'cb'
229        stream = asyncio.StreamReader(loop=self.loop)
230        stream.feed_data(b'chunk1 ')
231        read_task = self.loop.create_task(stream.readline())
232
233        def cb():
234            stream.feed_data(b'chunk2 ')
235            stream.feed_data(b'chunk3 ')
236            stream.feed_data(b'\n chunk4')
237        self.loop.call_soon(cb)
238
239        line = self.loop.run_until_complete(read_task)
240        self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
241        self.assertEqual(b' chunk4', stream._buffer)
242
243    def test_readline_limit_with_existing_data(self):
244        # Read one line. The data is in StreamReader's buffer
245        # before the event loop is run.
246
247        stream = asyncio.StreamReader(limit=3, loop=self.loop)
248        stream.feed_data(b'li')
249        stream.feed_data(b'ne1\nline2\n')
250
251        self.assertRaises(
252            ValueError, self.loop.run_until_complete, stream.readline())
253        # The buffer should contain the remaining data after exception
254        self.assertEqual(b'line2\n', stream._buffer)
255
256        stream = asyncio.StreamReader(limit=3, loop=self.loop)
257        stream.feed_data(b'li')
258        stream.feed_data(b'ne1')
259        stream.feed_data(b'li')
260
261        self.assertRaises(
262            ValueError, self.loop.run_until_complete, stream.readline())
263        # No b'\n' at the end. The 'limit' is set to 3. So before
264        # waiting for the new data in buffer, 'readline' will consume
265        # the entire buffer, and since the length of the consumed data
266        # is more than 3, it will raise a ValueError. The buffer is
267        # expected to be empty now.
268        self.assertEqual(b'', stream._buffer)
269
270    def test_at_eof(self):
271        stream = asyncio.StreamReader(loop=self.loop)
272        self.assertFalse(stream.at_eof())
273
274        stream.feed_data(b'some data\n')
275        self.assertFalse(stream.at_eof())
276
277        self.loop.run_until_complete(stream.readline())
278        self.assertFalse(stream.at_eof())
279
280        stream.feed_data(b'some data\n')
281        stream.feed_eof()
282        self.loop.run_until_complete(stream.readline())
283        self.assertTrue(stream.at_eof())
284
285    def test_readline_limit(self):
286        # Read one line. StreamReaders are fed with data after
287        # their 'readline' methods are called.
288
289        stream = asyncio.StreamReader(limit=7, loop=self.loop)
290        def cb():
291            stream.feed_data(b'chunk1')
292            stream.feed_data(b'chunk2')
293            stream.feed_data(b'chunk3\n')
294            stream.feed_eof()
295        self.loop.call_soon(cb)
296
297        self.assertRaises(
298            ValueError, self.loop.run_until_complete, stream.readline())
299        # The buffer had just one line of data, and after raising
300        # a ValueError it should be empty.
301        self.assertEqual(b'', stream._buffer)
302
303        stream = asyncio.StreamReader(limit=7, loop=self.loop)
304        def cb():
305            stream.feed_data(b'chunk1')
306            stream.feed_data(b'chunk2\n')
307            stream.feed_data(b'chunk3\n')
308            stream.feed_eof()
309        self.loop.call_soon(cb)
310
311        self.assertRaises(
312            ValueError, self.loop.run_until_complete, stream.readline())
313        self.assertEqual(b'chunk3\n', stream._buffer)
314
315        # check strictness of the limit
316        stream = asyncio.StreamReader(limit=7, loop=self.loop)
317        stream.feed_data(b'1234567\n')
318        line = self.loop.run_until_complete(stream.readline())
319        self.assertEqual(b'1234567\n', line)
320        self.assertEqual(b'', stream._buffer)
321
322        stream.feed_data(b'12345678\n')
323        with self.assertRaises(ValueError) as cm:
324            self.loop.run_until_complete(stream.readline())
325        self.assertEqual(b'', stream._buffer)
326
327        stream.feed_data(b'12345678')
328        with self.assertRaises(ValueError) as cm:
329            self.loop.run_until_complete(stream.readline())
330        self.assertEqual(b'', stream._buffer)
331
332    def test_readline_nolimit_nowait(self):
333        # All needed data for the first 'readline' call will be
334        # in the buffer.
335        stream = asyncio.StreamReader(loop=self.loop)
336        stream.feed_data(self.DATA[:6])
337        stream.feed_data(self.DATA[6:])
338
339        line = self.loop.run_until_complete(stream.readline())
340
341        self.assertEqual(b'line1\n', line)
342        self.assertEqual(b'line2\nline3\n', stream._buffer)
343
344    def test_readline_eof(self):
345        stream = asyncio.StreamReader(loop=self.loop)
346        stream.feed_data(b'some data')
347        stream.feed_eof()
348
349        line = self.loop.run_until_complete(stream.readline())
350        self.assertEqual(b'some data', line)
351
352    def test_readline_empty_eof(self):
353        stream = asyncio.StreamReader(loop=self.loop)
354        stream.feed_eof()
355
356        line = self.loop.run_until_complete(stream.readline())
357        self.assertEqual(b'', line)
358
359    def test_readline_read_byte_count(self):
360        stream = asyncio.StreamReader(loop=self.loop)
361        stream.feed_data(self.DATA)
362
363        self.loop.run_until_complete(stream.readline())
364
365        data = self.loop.run_until_complete(stream.read(7))
366
367        self.assertEqual(b'line2\nl', data)
368        self.assertEqual(b'ine3\n', stream._buffer)
369
370    def test_readline_exception(self):
371        stream = asyncio.StreamReader(loop=self.loop)
372        stream.feed_data(b'line\n')
373
374        data = self.loop.run_until_complete(stream.readline())
375        self.assertEqual(b'line\n', data)
376
377        stream.set_exception(ValueError())
378        self.assertRaises(
379            ValueError, self.loop.run_until_complete, stream.readline())
380        self.assertEqual(b'', stream._buffer)
381
382    def test_readuntil_separator(self):
383        stream = asyncio.StreamReader(loop=self.loop)
384        with self.assertRaisesRegex(ValueError, 'Separator should be'):
385            self.loop.run_until_complete(stream.readuntil(separator=b''))
386
387    def test_readuntil_multi_chunks(self):
388        stream = asyncio.StreamReader(loop=self.loop)
389
390        stream.feed_data(b'lineAAA')
391        data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
392        self.assertEqual(b'lineAAA', data)
393        self.assertEqual(b'', stream._buffer)
394
395        stream.feed_data(b'lineAAA')
396        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
397        self.assertEqual(b'lineAAA', data)
398        self.assertEqual(b'', stream._buffer)
399
400        stream.feed_data(b'lineAAAxxx')
401        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
402        self.assertEqual(b'lineAAA', data)
403        self.assertEqual(b'xxx', stream._buffer)
404
405    def test_readuntil_multi_chunks_1(self):
406        stream = asyncio.StreamReader(loop=self.loop)
407
408        stream.feed_data(b'QWEaa')
409        stream.feed_data(b'XYaa')
410        stream.feed_data(b'a')
411        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
412        self.assertEqual(b'QWEaaXYaaa', data)
413        self.assertEqual(b'', stream._buffer)
414
415        stream.feed_data(b'QWEaa')
416        stream.feed_data(b'XYa')
417        stream.feed_data(b'aa')
418        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
419        self.assertEqual(b'QWEaaXYaaa', data)
420        self.assertEqual(b'', stream._buffer)
421
422        stream.feed_data(b'aaa')
423        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
424        self.assertEqual(b'aaa', data)
425        self.assertEqual(b'', stream._buffer)
426
427        stream.feed_data(b'Xaaa')
428        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
429        self.assertEqual(b'Xaaa', data)
430        self.assertEqual(b'', stream._buffer)
431
432        stream.feed_data(b'XXX')
433        stream.feed_data(b'a')
434        stream.feed_data(b'a')
435        stream.feed_data(b'a')
436        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
437        self.assertEqual(b'XXXaaa', data)
438        self.assertEqual(b'', stream._buffer)
439
440    def test_readuntil_eof(self):
441        stream = asyncio.StreamReader(loop=self.loop)
442        data = b'some dataAA'
443        stream.feed_data(data)
444        stream.feed_eof()
445
446        with self.assertRaisesRegex(asyncio.IncompleteReadError,
447                                    'undefined expected bytes') as cm:
448            self.loop.run_until_complete(stream.readuntil(b'AAA'))
449        self.assertEqual(cm.exception.partial, data)
450        self.assertIsNone(cm.exception.expected)
451        self.assertEqual(b'', stream._buffer)
452
453    def test_readuntil_limit_found_sep(self):
454        stream = asyncio.StreamReader(loop=self.loop, limit=3)
455        stream.feed_data(b'some dataAA')
456        with self.assertRaisesRegex(asyncio.LimitOverrunError,
457                                    'not found') as cm:
458            self.loop.run_until_complete(stream.readuntil(b'AAA'))
459
460        self.assertEqual(b'some dataAA', stream._buffer)
461
462        stream.feed_data(b'A')
463        with self.assertRaisesRegex(asyncio.LimitOverrunError,
464                                    'is found') as cm:
465            self.loop.run_until_complete(stream.readuntil(b'AAA'))
466
467        self.assertEqual(b'some dataAAA', stream._buffer)
468
469    def test_readexactly_zero_or_less(self):
470        # Read exact number of bytes (zero or less).
471        stream = asyncio.StreamReader(loop=self.loop)
472        stream.feed_data(self.DATA)
473
474        data = self.loop.run_until_complete(stream.readexactly(0))
475        self.assertEqual(b'', data)
476        self.assertEqual(self.DATA, stream._buffer)
477
478        with self.assertRaisesRegex(ValueError, 'less than zero'):
479            self.loop.run_until_complete(stream.readexactly(-1))
480        self.assertEqual(self.DATA, stream._buffer)
481
482    def test_readexactly(self):
483        # Read exact number of bytes.
484        stream = asyncio.StreamReader(loop=self.loop)
485
486        n = 2 * len(self.DATA)
487        read_task = self.loop.create_task(stream.readexactly(n))
488
489        def cb():
490            stream.feed_data(self.DATA)
491            stream.feed_data(self.DATA)
492            stream.feed_data(self.DATA)
493        self.loop.call_soon(cb)
494
495        data = self.loop.run_until_complete(read_task)
496        self.assertEqual(self.DATA + self.DATA, data)
497        self.assertEqual(self.DATA, stream._buffer)
498
499    def test_readexactly_limit(self):
500        stream = asyncio.StreamReader(limit=3, loop=self.loop)
501        stream.feed_data(b'chunk')
502        data = self.loop.run_until_complete(stream.readexactly(5))
503        self.assertEqual(b'chunk', data)
504        self.assertEqual(b'', stream._buffer)
505
506    def test_readexactly_eof(self):
507        # Read exact number of bytes (eof).
508        stream = asyncio.StreamReader(loop=self.loop)
509        n = 2 * len(self.DATA)
510        read_task = self.loop.create_task(stream.readexactly(n))
511
512        def cb():
513            stream.feed_data(self.DATA)
514            stream.feed_eof()
515        self.loop.call_soon(cb)
516
517        with self.assertRaises(asyncio.IncompleteReadError) as cm:
518            self.loop.run_until_complete(read_task)
519        self.assertEqual(cm.exception.partial, self.DATA)
520        self.assertEqual(cm.exception.expected, n)
521        self.assertEqual(str(cm.exception),
522                         '18 bytes read on a total of 36 expected bytes')
523        self.assertEqual(b'', stream._buffer)
524
525    def test_readexactly_exception(self):
526        stream = asyncio.StreamReader(loop=self.loop)
527        stream.feed_data(b'line\n')
528
529        data = self.loop.run_until_complete(stream.readexactly(2))
530        self.assertEqual(b'li', data)
531
532        stream.set_exception(ValueError())
533        self.assertRaises(
534            ValueError, self.loop.run_until_complete, stream.readexactly(2))
535
536    def test_exception(self):
537        stream = asyncio.StreamReader(loop=self.loop)
538        self.assertIsNone(stream.exception())
539
540        exc = ValueError()
541        stream.set_exception(exc)
542        self.assertIs(stream.exception(), exc)
543
544    def test_exception_waiter(self):
545        stream = asyncio.StreamReader(loop=self.loop)
546
547        async def set_err():
548            stream.set_exception(ValueError())
549
550        t1 = self.loop.create_task(stream.readline())
551        t2 = self.loop.create_task(set_err())
552
553        self.loop.run_until_complete(asyncio.wait([t1, t2]))
554
555        self.assertRaises(ValueError, t1.result)
556
557    def test_exception_cancel(self):
558        stream = asyncio.StreamReader(loop=self.loop)
559
560        t = self.loop.create_task(stream.readline())
561        test_utils.run_briefly(self.loop)
562        t.cancel()
563        test_utils.run_briefly(self.loop)
564        # The following line fails if set_exception() isn't careful.
565        stream.set_exception(RuntimeError('message'))
566        test_utils.run_briefly(self.loop)
567        self.assertIs(stream._waiter, None)
568
    def test_start_server(self):
        # Echo one line through asyncio.start_server() twice: once with a
        # coroutine client handler, once with a plain callback that schedules
        # that coroutine as a task.  Anything reported to the loop's
        # exception handler fails the test.

        class MyServer:
            # Minimal line-echo server driven synchronously from the test.

            def __init__(self, loop):
                self.server = None
                self.loop = loop

            async def handle_client(self, client_reader, client_writer):
                # Echo back a single line, then close the connection.
                data = await client_reader.readline()
                client_writer.write(data)
                await client_writer.drain()
                client_writer.close()
                await client_writer.wait_closed()

            def start(self):
                # Hand an already-bound listening socket to start_server()
                # and return its (host, port) address.
                sock = socket.create_server(('127.0.0.1', 0))
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client,
                                         sock=sock))
                return sock.getsockname()

            def handle_client_callback(self, client_reader, client_writer):
                # Non-coroutine handler: schedule the echo coroutine.
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                # Find a free port by binding and closing a temporary socket,
                # then let start_server() bind host/port itself.
                # NOTE(review): the port could be taken by someone else in
                # between; this is inherently racy.
                sock = socket.create_server(('127.0.0.1', 0))
                addr = sock.getsockname()
                sock.close()
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client_callback,
                                         host=addr[0], port=addr[1]))
                return addr

            def stop(self):
                # Close the server (if started) and wait until it is closed.
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        async def client(addr):
            # Send one line and return whatever line comes back.
            reader, writer = await asyncio.open_connection(*addr)
            # send a line
            writer.write(b"hello world!\n")
            # read it back
            msgback = await reader.readline()
            writer.close()
            await writer.wait_closed()
            return msgback

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        # test the server variant with a coroutine as client handler
        server = MyServer(self.loop)
        addr = server.start()
        msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
        server.stop()
        self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        server = MyServer(self.loop)
        addr = server.start_callback()
        msg = self.loop.run_until_complete(self.loop.create_task(client(addr)))
        server.stop()
        self.assertEqual(msg, b"hello world!\n")

        self.assertEqual(messages, [])
638
    @socket_helper.skip_unless_bind_unix_socket
    def test_start_unix_server(self):
        # Echo one line through asyncio.start_unix_server() twice: once with
        # a coroutine client handler, once with a plain callback that
        # schedules that coroutine as a task.  Each run uses a fresh socket
        # path.  Anything reported to the loop's exception handler fails the
        # test.

        class MyServer:
            # Minimal line-echo server listening on the Unix socket *path*.

            def __init__(self, loop, path):
                self.server = None
                self.loop = loop
                self.path = path

            async def handle_client(self, client_reader, client_writer):
                # Echo back a single line, then close the connection.
                data = await client_reader.readline()
                client_writer.write(data)
                await client_writer.drain()
                client_writer.close()
                await client_writer.wait_closed()

            def start(self):
                # Serve with the coroutine handler.
                self.server = self.loop.run_until_complete(
                    asyncio.start_unix_server(self.handle_client,
                                              path=self.path))

            def handle_client_callback(self, client_reader, client_writer):
                # Non-coroutine handler: schedule the echo coroutine.
                self.loop.create_task(self.handle_client(client_reader,
                                                         client_writer))

            def start_callback(self):
                # Serve with the plain-callback handler.
                start = asyncio.start_unix_server(self.handle_client_callback,
                                                  path=self.path)
                self.server = self.loop.run_until_complete(start)

            def stop(self):
                # Close the server (if started) and wait until it is closed.
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        async def client(path):
            # Send one line and return whatever line comes back.
            reader, writer = await asyncio.open_unix_connection(path)
            # send a line
            writer.write(b"hello world!\n")
            # read it back
            msgback = await reader.readline()
            writer.close()
            await writer.wait_closed()
            return msgback

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        # test the server variant with a coroutine as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start()
            msg = self.loop.run_until_complete(
                self.loop.create_task(client(path)))
            server.stop()
            self.assertEqual(msg, b"hello world!\n")

        # test the server variant with a callback as client handler
        with test_utils.unix_socket_path() as path:
            server = MyServer(self.loop, path)
            server.start_callback()
            msg = self.loop.run_until_complete(
                self.loop.create_task(client(path)))
            server.stop()
            self.assertEqual(msg, b"hello world!\n")

        self.assertEqual(messages, [])
708
    @unittest.skipIf(ssl is None, 'No ssl module')
    def test_start_tls(self):
        # Upgrade an established plain-TCP stream connection to TLS with
        # StreamWriter.start_tls() on both ends.  One line is echoed before
        # the upgrade and one after.  Anything reported to the loop's
        # exception handler fails the test.

        class MyServer:
            # Line-echo server that switches to TLS after the first line.

            def __init__(self, loop):
                self.server = None
                self.loop = loop

            async def handle_client(self, client_reader, client_writer):
                # First line is echoed in plaintext.
                data1 = await client_reader.readline()
                client_writer.write(data1)
                await client_writer.drain()
                # NOTE(review): these bare asserts are stripped under -O.
                assert client_writer.get_extra_info('sslcontext') is None
                await client_writer.start_tls(
                    test_utils.simple_server_sslcontext())
                assert client_writer.get_extra_info('sslcontext') is not None
                # Second line is echoed over the TLS session.
                data2 = await client_reader.readline()
                client_writer.write(data2)
                await client_writer.drain()
                client_writer.close()
                await client_writer.wait_closed()

            def start(self):
                # Hand an already-bound listening socket to start_server()
                # and return its (host, port) address.
                sock = socket.create_server(('127.0.0.1', 0))
                self.server = self.loop.run_until_complete(
                    asyncio.start_server(self.handle_client,
                                         sock=sock))
                return sock.getsockname()

            def stop(self):
                # Close the server (if started) and wait until it is closed.
                if self.server is not None:
                    self.server.close()
                    self.loop.run_until_complete(self.server.wait_closed())
                    self.server = None

        async def client(addr):
            # Plaintext round trip, upgrade to TLS, then a TLS round trip.
            reader, writer = await asyncio.open_connection(*addr)
            writer.write(b"hello world 1!\n")
            await writer.drain()
            msgback1 = await reader.readline()
            assert writer.get_extra_info('sslcontext') is None
            await writer.start_tls(test_utils.simple_client_sslcontext())
            assert writer.get_extra_info('sslcontext') is not None
            writer.write(b"hello world 2!\n")
            await writer.drain()
            msgback2 = await reader.readline()
            writer.close()
            await writer.wait_closed()
            return msgback1, msgback2

        messages = []
        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))

        server = MyServer(self.loop)
        addr = server.start()
        msg1, msg2 = self.loop.run_until_complete(client(addr))
        server.stop()

        self.assertEqual(messages, [])
        self.assertEqual(msg1, b"hello world 1!\n")
        self.assertEqual(msg2, b"hello world 2!\n")
771
772    @unittest.skipIf(sys.platform == 'win32', "Don't have pipes")
773    def test_read_all_from_pipe_reader(self):
774        # See asyncio issue 168.  This test is derived from the example
775        # subprocess_attach_read_pipe.py, but we configure the
776        # StreamReader's limit so that twice it is less than the size
777        # of the data writer.  Also we must explicitly attach a child
778        # watcher to the event loop.
779
780        code = """\
781import os, sys
782fd = int(sys.argv[1])
783os.write(fd, b'data')
784os.close(fd)
785"""
786        rfd, wfd = os.pipe()
787        args = [sys.executable, '-c', code, str(wfd)]
788
789        pipe = open(rfd, 'rb', 0)
790        reader = asyncio.StreamReader(loop=self.loop, limit=1)
791        protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
792        transport, _ = self.loop.run_until_complete(
793            self.loop.connect_read_pipe(lambda: protocol, pipe))
794
795        watcher = asyncio.SafeChildWatcher()
796        watcher.attach_loop(self.loop)
797        try:
798            asyncio.set_child_watcher(watcher)
799            create = asyncio.create_subprocess_exec(
800                *args,
801                pass_fds={wfd},
802            )
803            proc = self.loop.run_until_complete(create)
804            self.loop.run_until_complete(proc.wait())
805        finally:
806            asyncio.set_child_watcher(None)
807
808        os.close(wfd)
809        data = self.loop.run_until_complete(reader.read(-1))
810        self.assertEqual(data, b'data')
811
812    def test_streamreader_constructor_without_loop(self):
813        with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
814            asyncio.StreamReader()
815
816    def test_streamreader_constructor_use_running_loop(self):
817        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
818        # retrieves the current loop if the loop parameter is not set
819        async def test():
820            return asyncio.StreamReader()
821
822        reader = self.loop.run_until_complete(test())
823        self.assertIs(reader._loop, self.loop)
824
825    def test_streamreader_constructor_use_global_loop(self):
826        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
827        # retrieves the current loop if the loop parameter is not set
828        # Deprecated in 3.10, undeprecated in 3.11.1
829        self.addCleanup(asyncio.set_event_loop, None)
830        asyncio.set_event_loop(self.loop)
831        reader = asyncio.StreamReader()
832        self.assertIs(reader._loop, self.loop)
833
834
835    def test_streamreaderprotocol_constructor_without_loop(self):
836        reader = mock.Mock()
837        with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
838            asyncio.StreamReaderProtocol(reader)
839
840    def test_streamreaderprotocol_constructor_use_running_loop(self):
841        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
842        # retrieves the current loop if the loop parameter is not set
843        reader = mock.Mock()
844        async def test():
845            return asyncio.StreamReaderProtocol(reader)
846        protocol = self.loop.run_until_complete(test())
847        self.assertIs(protocol._loop, self.loop)
848
849    def test_streamreaderprotocol_constructor_use_global_loop(self):
850        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
851        # retrieves the current loop if the loop parameter is not set
852        # Deprecated in 3.10, undeprecated in 3.11.1
853        self.addCleanup(asyncio.set_event_loop, None)
854        asyncio.set_event_loop(self.loop)
855        reader = mock.Mock()
856        protocol = asyncio.StreamReaderProtocol(reader)
857        self.assertIs(protocol._loop, self.loop)
858
859    def test_multiple_drain(self):
860        # See https://github.com/python/cpython/issues/74116
861        drained = 0
862
863        async def drainer(stream):
864            nonlocal drained
865            await stream._drain_helper()
866            drained += 1
867
868        async def main():
869            loop = asyncio.get_running_loop()
870            stream = asyncio.streams.FlowControlMixin(loop)
871            stream.pause_writing()
872            loop.call_later(0.1, stream.resume_writing)
873            await asyncio.gather(*[drainer(stream) for _ in range(10)])
874            self.assertEqual(drained, 10)
875
876        self.loop.run_until_complete(main())
877
878    def test_drain_raises(self):
879        # See http://bugs.python.org/issue25441
880
881        # This test should not use asyncio for the mock server; the
882        # whole point of the test is to test for a bug in drain()
883        # where it never gives up the event loop but the socket is
884        # closed on the  server side.
885
886        messages = []
887        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
888        q = queue.Queue()
889
890        def server():
891            # Runs in a separate thread.
892            with socket.create_server(('localhost', 0)) as sock:
893                addr = sock.getsockname()
894                q.put(addr)
895                clt, _ = sock.accept()
896                clt.close()
897
898        async def client(host, port):
899            reader, writer = await asyncio.open_connection(host, port)
900
901            while True:
902                writer.write(b"foo\n")
903                await writer.drain()
904
905        # Start the server thread and wait for it to be listening.
906        thread = threading.Thread(target=server)
907        thread.daemon = True
908        thread.start()
909        addr = q.get()
910
911        # Should not be stuck in an infinite loop.
912        with self.assertRaises((ConnectionResetError, ConnectionAbortedError,
913                                BrokenPipeError)):
914            self.loop.run_until_complete(client(*addr))
915
916        # Clean up the thread.  (Only on success; on failure, it may
917        # be stuck in accept().)
918        thread.join()
919        self.assertEqual([], messages)
920
921    def test___repr__(self):
922        stream = asyncio.StreamReader(loop=self.loop)
923        self.assertEqual("<StreamReader>", repr(stream))
924
925    def test___repr__nondefault_limit(self):
926        stream = asyncio.StreamReader(loop=self.loop, limit=123)
927        self.assertEqual("<StreamReader limit=123>", repr(stream))
928
929    def test___repr__eof(self):
930        stream = asyncio.StreamReader(loop=self.loop)
931        stream.feed_eof()
932        self.assertEqual("<StreamReader eof>", repr(stream))
933
934    def test___repr__data(self):
935        stream = asyncio.StreamReader(loop=self.loop)
936        stream.feed_data(b'data')
937        self.assertEqual("<StreamReader 4 bytes>", repr(stream))
938
939    def test___repr__exception(self):
940        stream = asyncio.StreamReader(loop=self.loop)
941        exc = RuntimeError()
942        stream.set_exception(exc)
943        self.assertEqual("<StreamReader exception=RuntimeError()>",
944                         repr(stream))
945
946    def test___repr__waiter(self):
947        stream = asyncio.StreamReader(loop=self.loop)
948        stream._waiter = asyncio.Future(loop=self.loop)
949        self.assertRegex(
950            repr(stream),
951            r"<StreamReader waiter=<Future pending[\S ]*>>")
952        stream._waiter.set_result(None)
953        self.loop.run_until_complete(stream._waiter)
954        stream._waiter = None
955        self.assertEqual("<StreamReader>", repr(stream))
956
957    def test___repr__transport(self):
958        stream = asyncio.StreamReader(loop=self.loop)
959        stream._transport = mock.Mock()
960        stream._transport.__repr__ = mock.Mock()
961        stream._transport.__repr__.return_value = "<Transport>"
962        self.assertEqual("<StreamReader transport=<Transport>>", repr(stream))
963
964    def test_IncompleteReadError_pickleable(self):
965        e = asyncio.IncompleteReadError(b'abc', 10)
966        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
967            with self.subTest(pickle_protocol=proto):
968                e2 = pickle.loads(pickle.dumps(e, protocol=proto))
969                self.assertEqual(str(e), str(e2))
970                self.assertEqual(e.partial, e2.partial)
971                self.assertEqual(e.expected, e2.expected)
972
973    def test_LimitOverrunError_pickleable(self):
974        e = asyncio.LimitOverrunError('message', 10)
975        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
976            with self.subTest(pickle_protocol=proto):
977                e2 = pickle.loads(pickle.dumps(e, protocol=proto))
978                self.assertEqual(str(e), str(e2))
979                self.assertEqual(e.consumed, e2.consumed)
980
981    def test_wait_closed_on_close(self):
982        with test_utils.run_test_server() as httpd:
983            rd, wr = self.loop.run_until_complete(
984                asyncio.open_connection(*httpd.address))
985
986            wr.write(b'GET / HTTP/1.0\r\n\r\n')
987            f = rd.readline()
988            data = self.loop.run_until_complete(f)
989            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
990            f = rd.read()
991            data = self.loop.run_until_complete(f)
992            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
993            self.assertFalse(wr.is_closing())
994            wr.close()
995            self.assertTrue(wr.is_closing())
996            self.loop.run_until_complete(wr.wait_closed())
997
998    def test_wait_closed_on_close_with_unread_data(self):
999        with test_utils.run_test_server() as httpd:
1000            rd, wr = self.loop.run_until_complete(
1001                asyncio.open_connection(*httpd.address))
1002
1003            wr.write(b'GET / HTTP/1.0\r\n\r\n')
1004            f = rd.readline()
1005            data = self.loop.run_until_complete(f)
1006            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
1007            wr.close()
1008            self.loop.run_until_complete(wr.wait_closed())
1009
1010    def test_async_writer_api(self):
1011        async def inner(httpd):
1012            rd, wr = await asyncio.open_connection(*httpd.address)
1013
1014            wr.write(b'GET / HTTP/1.0\r\n\r\n')
1015            data = await rd.readline()
1016            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
1017            data = await rd.read()
1018            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
1019            wr.close()
1020            await wr.wait_closed()
1021
1022        messages = []
1023        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
1024
1025        with test_utils.run_test_server() as httpd:
1026            self.loop.run_until_complete(inner(httpd))
1027
1028        self.assertEqual(messages, [])
1029
1030    def test_async_writer_api_exception_after_close(self):
1031        async def inner(httpd):
1032            rd, wr = await asyncio.open_connection(*httpd.address)
1033
1034            wr.write(b'GET / HTTP/1.0\r\n\r\n')
1035            data = await rd.readline()
1036            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
1037            data = await rd.read()
1038            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
1039            wr.close()
1040            with self.assertRaises(ConnectionResetError):
1041                wr.write(b'data')
1042                await wr.drain()
1043
1044        messages = []
1045        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
1046
1047        with test_utils.run_test_server() as httpd:
1048            self.loop.run_until_complete(inner(httpd))
1049
1050        self.assertEqual(messages, [])
1051
1052    def test_eof_feed_when_closing_writer(self):
1053        # See http://bugs.python.org/issue35065
1054        messages = []
1055        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
1056
1057        with test_utils.run_test_server() as httpd:
1058            rd, wr = self.loop.run_until_complete(
1059                    asyncio.open_connection(*httpd.address))
1060
1061            wr.close()
1062            f = wr.wait_closed()
1063            self.loop.run_until_complete(f)
1064            self.assertTrue(rd.at_eof())
1065            f = rd.read()
1066            data = self.loop.run_until_complete(f)
1067            self.assertEqual(data, b'')
1068
1069        self.assertEqual(messages, [])
1070
1071
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
1074