1"""Tests for sendfile functionality."""
2
3import asyncio
4import errno
5import os
6import socket
7import sys
8import tempfile
9import unittest
10from asyncio import base_events
11from asyncio import constants
12from unittest import mock
13from test import support
14from test.support import os_helper
15from test.support import socket_helper
16from test.test_asyncio import utils as test_utils
17
18try:
19    import ssl
20except ImportError:
21    ssl = None
22
23
def tearDownModule():
    """Undo any event loop policy changes made while this module ran."""
    # Passing None tells asyncio to go back to its default policy.
    asyncio.set_event_loop_policy(None)
26
27
class MySendfileProto(asyncio.Protocol):
    """Protocol that records received bytes and checks its state machine.

    States move INITIAL -> CONNECTED -> (EOF ->) CLOSED; any other
    transition raises AssertionError.  When *loop* is given, ``connected``
    and ``done`` are futures resolved by connection_made() and
    connection_lost() respectively; otherwise both are None.  If
    *close_after* is non-zero, the transport is closed once at least that
    many bytes have been received.
    """

    def __init__(self, loop=None, close_after=0):
        self.transport = None
        self.state = 'INITIAL'
        self.nbytes = 0
        # Always define the futures so the callbacks can check them even
        # when the protocol is built without a loop.
        self.connected = None
        self.done = None
        if loop is not None:
            self.connected = loop.create_future()
            self.done = loop.create_future()
        self.data = bytearray()
        self.close_after = close_after

    def _assert_state(self, *expected):
        """Raise AssertionError unless the current state is in *expected*."""
        if self.state not in expected:
            raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')

    def connection_made(self, transport):
        self.transport = transport
        self._assert_state('INITIAL')
        self.state = 'CONNECTED'
        if self.connected is not None:
            self.connected.set_result(None)

    def eof_received(self):
        self._assert_state('CONNECTED')
        self.state = 'EOF'

    def connection_lost(self, exc):
        self._assert_state('CONNECTED', 'EOF')
        self.state = 'CLOSED'
        if self.done is not None:
            self.done.set_result(None)

    def data_received(self, data):
        self._assert_state('CONNECTED')
        self.nbytes += len(data)
        self.data.extend(data)
        super().data_received(data)
        # Simulate a peer that hangs up after receiving close_after bytes.
        if self.close_after and self.nbytes >= self.close_after:
            self.transport.close()
68
69
class MyProto(asyncio.Protocol):
    """Server-side protocol that accumulates every byte it receives.

    ``fut`` is resolved from connection_lost(); wait_closed() awaits it.
    """

    def __init__(self, loop):
        self.transport = None
        self.started = self.closed = False
        self.data = bytearray()
        self.fut = loop.create_future()

    def connection_made(self, transport):
        self.transport, self.started = transport, True

    def data_received(self, data):
        self.data += data

    def connection_lost(self, exc):
        self.closed = True
        self.fut.set_result(None)

    async def wait_closed(self):
        """Wait until connection_lost() has been called."""
        await self.fut
92
93
class SendfileBase:
    """Common fixture: a file holding DATA plus a fresh event loop per test.

    Concrete subclasses must override create_event_loop().
    """

    # DATA is 256 KiB plus one byte, so it never lines up with a buffer
    # chunk.  Newer Windows releases (recent server versions and Windows 11)
    # appear to have a larger internal send buffer and push as much data as
    # they can into it; that buffer is somewhat smaller than 256 KiB, so DATA
    # must exceed 256 KiB for these tests to be reliable.
    DATA = b"x" * (256 * 1024 + 1)
    # Shrink socket buffers so relatively small data sets still exercise
    # partial sends.
    BUF_SIZE = 4096  # 4 KiB

    def create_event_loop(self):
        """Return the event loop under test; subclasses must override."""
        raise NotImplementedError

    @classmethod
    def setUpClass(cls):
        # Write the payload to disk once for the whole test class.
        with open(os_helper.TESTFN, 'wb') as payload:
            payload.write(cls.DATA)
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        os_helper.unlink(os_helper.TESTFN)
        super().tearDownClass()

    def setUp(self):
        self.file = fobj = open(os_helper.TESTFN, 'rb')
        self.addCleanup(fobj.close)
        loop = self.create_event_loop()
        self.loop = loop
        self.set_event_loop(loop)
        super().setUp()

    def tearDown(self):
        loop = self.loop
        # Give any pending transport close callbacks a chance to run.
        if not loop.is_closed():
            test_utils.run_briefly(loop)

        self.doCleanups()
        support.gc_collect()
        super().tearDown()

    def run_loop(self, coro):
        """Run *coro* to completion on self.loop and return its result."""
        loop = self.loop
        return loop.run_until_complete(coro)
138
139
class SockSendfileMixin(SendfileBase):
    """Tests for loop.sock_sendfile() over plain non-blocking TCP sockets."""

    @classmethod
    def setUpClass(cls):
        # Use a 16 KiB fallback read buffer so the fallback path needs
        # several reads to move DATA; restored in tearDownClass().
        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
        super().tearDownClass()

    def make_socket(self, cleanup=True):
        """Return a new non-blocking IPv4 TCP socket.

        If *cleanup* is true, the socket is closed when the test finishes.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)
        if cleanup:
            self.addCleanup(sock.close)
        return sock

    def reduce_receive_buffer_size(self, sock):
        # Reduce receive socket buffer size to test on relatively
        # small data sets.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)

    def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce send socket buffer size to test on relatively small data sets.

        # On macOS, SO_SNDBUF is reset by connect(). So this method
        # should be called after the socket is connected.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)

        if transport is not None:
            transport.set_write_buffer_limits(high=self.BUF_SIZE)

    def prepare_socksendfile(self):
        """Start a MyProto server and return ``(client_sock, server_proto)``.

        The server binds an OS-assigned ephemeral port, which avoids the
        window between find_unused_port() and bind() in which another
        process could take the port.  Teardown of the connection and the
        server is registered with addCleanup().
        """
        proto = MyProto(self.loop)
        srv_sock = self.make_socket(cleanup=False)
        try:
            srv_sock.bind((socket_helper.HOST, 0))
            # Connect to exactly the address the server ended up bound to.
            srv_addr = srv_sock.getsockname()
            server = self.run_loop(self.loop.create_server(
                lambda: proto, sock=srv_sock))
        except BaseException:
            # The socket is not auto-closed (cleanup=False); don't leak it.
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        sock = self.make_socket()
        self.run_loop(self.loop.sock_connect(sock, srv_addr))
        self.reduce_send_buffer_size(sock)

        def cleanup():
            if proto.transport is not None:
                # can be None if the task was cancelled before
                # connection_made callback
                proto.transport.close()
                self.run_loop(proto.wait_closed())

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)

        return sock, proto

    def test_sock_sendfile_success(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sock_sendfile_with_offset_and_count(self):
        sock, proto = self.prepare_socksendfile()
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
                                                    1000, 2000))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(proto.data, self.DATA[1000:3000])
        self.assertEqual(self.file.tell(), 3000)
        self.assertEqual(ret, 2000)

    def test_sock_sendfile_zero_size(self):
        sock, proto = self.prepare_socksendfile()
        with tempfile.TemporaryFile() as f:
            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
                                                        0, None))
            # The empty temporary file is what was sent, so its position
            # is the one that must remain at zero.
            self.assertEqual(f.tell(), 0)
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, 0)
        self.assertEqual(proto.data, b'')
        self.assertEqual(self.file.tell(), 0)

    def test_sock_sendfile_mix_with_regular_send(self):
        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
        sock, proto = self.prepare_socksendfile()
        self.run_loop(self.loop.sock_sendall(sock, buf))
        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
        self.run_loop(self.loop.sock_sendall(sock, buf))
        sock.close()
        self.run_loop(proto.wait_closed())

        self.assertEqual(ret, len(self.DATA))
        expected = buf + self.DATA + buf
        self.assertEqual(proto.data, expected)
        self.assertEqual(self.file.tell(), len(self.DATA))
247
248
class SendfileMixin(SendfileBase):
    """Tests for loop.sendfile() over stream transports (plain and SSL).

    Note: sendfile via an SSL transport is equivalent to the sendfile
    fallback.
    """

    def _disable_native_sendfile(self):
        """Route self.loop's native sendfile to BaseEventLoop's version.

        The base implementation raises SendfileNotAvailableError, so
        loop.sendfile() must either take its fallback path or, with
        fallback=False, propagate that error.
        """
        def sendfile_native(transp, file, offset, count):
            # to raise SendfileNotAvailableError
            return base_events.BaseEventLoop._sendfile_native(
                self.loop, transp, file, offset, count)

        self.loop._sendfile_native = sendfile_native

    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
        """Connect a client transport to a server; return both protocols.

        Returns ``(srv_proto, cli_proto)``.  With *is_ssl*, both ends use the
        test SSL contexts (the test is skipped if ssl is unavailable).
        *close_after* is forwarded to the server protocol.  The server binds
        an OS-assigned ephemeral port to avoid racing other processes for a
        free port.
        """
        srv_proto = MySendfileProto(loop=self.loop,
                                    close_after=close_after)
        if is_ssl:
            if not ssl:
                self.skipTest("No ssl module")
            srv_ctx = test_utils.simple_server_sslcontext()
            cli_ctx = test_utils.simple_client_sslcontext()
        else:
            srv_ctx = None
            cli_ctx = None
        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            srv_sock.bind((socket_helper.HOST, 0))
            port = srv_sock.getsockname()[1]
            server = self.run_loop(self.loop.create_server(
                lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
        except BaseException:
            # Nothing else will close the socket on this path.
            srv_sock.close()
            raise
        self.reduce_receive_buffer_size(srv_sock)

        if is_ssl:
            server_hostname = socket_helper.HOST
        else:
            server_hostname = None
        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        cli_sock.connect((socket_helper.HOST, port))

        cli_proto = MySendfileProto(loop=self.loop)
        tr, _ = self.run_loop(self.loop.create_connection(
            lambda: cli_proto, sock=cli_sock,
            ssl=cli_ctx, server_hostname=server_hostname))
        self.reduce_send_buffer_size(cli_sock, transport=tr)

        def cleanup():
            srv_proto.transport.close()
            cli_proto.transport.close()
            self.run_loop(srv_proto.done)
            self.run_loop(cli_proto.done)

            server.close()
            self.run_loop(server.wait_closed())

        self.addCleanup(cleanup)
        return srv_proto, cli_proto

    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
    def test_sendfile_not_supported(self):
        tr, pr = self.run_loop(
            self.loop.create_datagram_endpoint(
                asyncio.DatagramProtocol,
                family=socket.AF_INET))
        try:
            with self.assertRaisesRegex(RuntimeError, "not supported"):
                self.run_loop(
                    self.loop.sendfile(tr, self.file))
            self.assertEqual(0, self.file.tell())
        finally:
            # don't use self.addCleanup because it produces resource warning
            tr.close()

    def test_sendfile(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_fallback(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_force_unsupported_native(self):
        if sys.platform == 'win32':
            if isinstance(self.loop, asyncio.ProactorEventLoop):
                self.skipTest("Fails on proactor event loop")
        srv_proto, cli_proto = self.prepare_sendfile()
        self._disable_native_sendfile()

        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
                                    "not supported"):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file,
                                   fallback=False))

        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_ssl(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_for_closing_transp(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        cli_proto.transport.close()
        with self.assertRaisesRegex(RuntimeError, "is closing"):
            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(srv_proto.nbytes, 0)
        self.assertEqual(self.file.tell(), 0)

    def test_sendfile_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        PREFIX = b'PREFIX__' * 1024  # 8 KiB
        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_pre_and_post_data(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        PREFIX = b'zxcvbnm' * 1024
        SUFFIX = b'0987654321' * 1024
        cli_proto.transport.write(PREFIX)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.write(SUFFIX)
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_ssl_partial(self):
        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, 100)
        self.assertEqual(srv_proto.nbytes, 100)
        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
        self.assertEqual(self.file.tell(), 1100)

    def test_sendfile_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        cli_proto.transport.close()
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    def test_sendfile_ssl_close_peer_after_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(
            is_ssl=True, close_after=len(self.DATA))
        ret = self.run_loop(
            self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)
        self.assertEqual(ret, len(self.DATA))
        self.assertEqual(srv_proto.nbytes, len(self.DATA))
        self.assertEqual(srv_proto.data, self.DATA)
        self.assertEqual(self.file.tell(), len(self.DATA))

    # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
    # established has no effect. Due to its age, this bug affects both Oracle
    # Solaris as well as all other OpenSolaris forks (unless they fixed it
    # themselves).
    @unittest.skipIf(sys.platform.startswith('sunos'),
                     "Doesn't work on Solaris")
    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            self.run_loop(
                self.loop.sendfile(cli_proto.transport, self.file))
        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())
        self.assertTrue(cli_proto.transport.is_closing())

    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
        self._disable_native_sendfile()

        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
        with self.assertRaises(ConnectionError):
            try:
                self.run_loop(
                    self.loop.sendfile(cli_proto.transport, self.file))
            except OSError as e:
                # macOS may raise OSError of EPROTOTYPE when writing to a
                # socket that is in the process of closing down.
                if e.errno == errno.EPROTOTYPE and sys.platform == "darwin":
                    raise ConnectionError from e
                raise

        self.run_loop(srv_proto.done)

        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
                        srv_proto.nbytes)
        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
                        self.file.tell())

    @unittest.skipIf(not hasattr(os, 'sendfile'),
                     "Don't have native sendfile support")
    def test_sendfile_prevents_bare_write(self):
        srv_proto, cli_proto = self.prepare_sendfile()
        fut = self.loop.create_future()

        async def coro():
            fut.set_result(None)
            return await self.loop.sendfile(cli_proto.transport, self.file)

        t = self.loop.create_task(coro())
        self.run_loop(fut)
        with self.assertRaisesRegex(RuntimeError,
                                    "sendfile is in progress"):
            cli_proto.transport.write(b'data')
        ret = self.run_loop(t)
        self.assertEqual(ret, len(self.DATA))

    def test_sendfile_no_fallback_for_fallback_transport(self):
        transport = mock.Mock()
        transport.is_closing.return_value = False
        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
            self.run_loop(
                self.loop.sendfile(transport, None, fallback=False))
531
532
class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
    """Combined loop.sendfile() and loop.sock_sendfile() test suite."""
535
536
# Concrete test cases: each class pairs SendfileTestsBase with
# test_utils.TestCase and only chooses which event loop to test.
if sys.platform == 'win32':

    # Windows: exercise both the selector-based and the proactor loops.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop()

    class ProactorEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.ProactorEventLoop()

else:
    import selectors

    # Non-Windows: define one test class per selector implementation that
    # this platform's selectors module provides.
    if hasattr(selectors, 'KqueueSelector'):
        class KqueueEventLoopTests(SendfileTestsBase,
                                   test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(
                    selectors.KqueueSelector())

    if hasattr(selectors, 'EpollSelector'):
        class EPollEventLoopTests(SendfileTestsBase,
                                  test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.EpollSelector())

    if hasattr(selectors, 'PollSelector'):
        class PollEventLoopTests(SendfileTestsBase,
                                 test_utils.TestCase):

            def create_event_loop(self):
                return asyncio.SelectorEventLoop(selectors.PollSelector())

    # Should always exist.
    class SelectEventLoopTests(SendfileTestsBase,
                               test_utils.TestCase):

        def create_event_loop(self):
            return asyncio.SelectorEventLoop(selectors.SelectSelector())


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
586