1"""Tests for sendfile functionality.""" 2 3import asyncio 4import errno 5import os 6import socket 7import sys 8import tempfile 9import unittest 10from asyncio import base_events 11from asyncio import constants 12from unittest import mock 13from test import support 14from test.support import os_helper 15from test.support import socket_helper 16from test.test_asyncio import utils as test_utils 17 18try: 19 import ssl 20except ImportError: 21 ssl = None 22 23 24def tearDownModule(): 25 asyncio.set_event_loop_policy(None) 26 27 28class MySendfileProto(asyncio.Protocol): 29 30 def __init__(self, loop=None, close_after=0): 31 self.transport = None 32 self.state = 'INITIAL' 33 self.nbytes = 0 34 if loop is not None: 35 self.connected = loop.create_future() 36 self.done = loop.create_future() 37 self.data = bytearray() 38 self.close_after = close_after 39 40 def _assert_state(self, *expected): 41 if self.state not in expected: 42 raise AssertionError(f'state: {self.state!r}, expected: {expected!r}') 43 44 def connection_made(self, transport): 45 self.transport = transport 46 self._assert_state('INITIAL') 47 self.state = 'CONNECTED' 48 if self.connected: 49 self.connected.set_result(None) 50 51 def eof_received(self): 52 self._assert_state('CONNECTED') 53 self.state = 'EOF' 54 55 def connection_lost(self, exc): 56 self._assert_state('CONNECTED', 'EOF') 57 self.state = 'CLOSED' 58 if self.done: 59 self.done.set_result(None) 60 61 def data_received(self, data): 62 self._assert_state('CONNECTED') 63 self.nbytes += len(data) 64 self.data.extend(data) 65 super().data_received(data) 66 if self.close_after and self.nbytes >= self.close_after: 67 self.transport.close() 68 69 70class MyProto(asyncio.Protocol): 71 72 def __init__(self, loop): 73 self.started = False 74 self.closed = False 75 self.data = bytearray() 76 self.fut = loop.create_future() 77 self.transport = None 78 79 def connection_made(self, transport): 80 self.started = True 81 self.transport = transport 82 83 def data_received(self, data): 84 self.data.extend(data) 85 86 def connection_lost(self, exc): 87 self.closed = True 88 self.fut.set_result(None) 89 90 async def wait_closed(self): 91 await self.fut 92 93 94class SendfileBase: 95 96 # 256 KiB plus small unaligned to buffer chunk 97 # Newer versions of Windows seems to have increased its internal 98 # buffer and tries to send as much of the data as it can as it 99 # has some form of buffering for this which is less than 256KiB 100 # on newer server versions and Windows 11. 101 # So DATA should be larger than 256 KiB to make this test reliable. 102 DATA = b"x" * (1024 * 256 + 1) 103 # Reduce socket buffer size to test on relative small data sets. 104 BUF_SIZE = 4 * 1024 # 4 KiB 105 106 def create_event_loop(self): 107 raise NotImplementedError 108 109 @classmethod 110 def setUpClass(cls): 111 with open(os_helper.TESTFN, 'wb') as fp: 112 fp.write(cls.DATA) 113 super().setUpClass() 114 115 @classmethod 116 def tearDownClass(cls): 117 os_helper.unlink(os_helper.TESTFN) 118 super().tearDownClass() 119 120 def setUp(self): 121 self.file = open(os_helper.TESTFN, 'rb') 122 self.addCleanup(self.file.close) 123 self.loop = self.create_event_loop() 124 self.set_event_loop(self.loop) 125 super().setUp() 126 127 def tearDown(self): 128 # just in case if we have transport close callbacks 129 if not self.loop.is_closed(): 130 test_utils.run_briefly(self.loop) 131 132 self.doCleanups() 133 support.gc_collect() 134 super().tearDown() 135 136 def run_loop(self, coro): 137 return self.loop.run_until_complete(coro) 138 139 140class SockSendfileMixin(SendfileBase): 141 142 @classmethod 143 def setUpClass(cls): 144 cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE 145 constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 146 super().setUpClass() 147 148 @classmethod 149 def tearDownClass(cls): 150 constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize 151 super().tearDownClass() 152 153 def make_socket(self, cleanup=True): 154 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 155 sock.setblocking(False) 156 if cleanup: 157 self.addCleanup(sock.close) 158 return sock 159 160 def reduce_receive_buffer_size(self, sock): 161 # Reduce receive socket buffer size to test on relative 162 # small data sets. 163 sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE) 164 165 def reduce_send_buffer_size(self, sock, transport=None): 166 # Reduce send socket buffer size to test on relative small data sets. 167 168 # On macOS, SO_SNDBUF is reset by connect(). So this method 169 # should be called after the socket is connected. 170 sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE) 171 172 if transport is not None: 173 transport.set_write_buffer_limits(high=self.BUF_SIZE) 174 175 def prepare_socksendfile(self): 176 proto = MyProto(self.loop) 177 port = socket_helper.find_unused_port() 178 srv_sock = self.make_socket(cleanup=False) 179 srv_sock.bind((socket_helper.HOST, port)) 180 server = self.run_loop(self.loop.create_server( 181 lambda: proto, sock=srv_sock)) 182 self.reduce_receive_buffer_size(srv_sock) 183 184 sock = self.make_socket() 185 self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) 186 self.reduce_send_buffer_size(sock) 187 188 def cleanup(): 189 if proto.transport is not None: 190 # can be None if the task was cancelled before 191 # connection_made callback 192 proto.transport.close() 193 self.run_loop(proto.wait_closed()) 194 195 server.close() 196 self.run_loop(server.wait_closed()) 197 198 self.addCleanup(cleanup) 199 200 return sock, proto 201 202 def test_sock_sendfile_success(self): 203 sock, proto = self.prepare_socksendfile() 204 ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) 205 sock.close() 206 self.run_loop(proto.wait_closed()) 207 208 self.assertEqual(ret, len(self.DATA)) 209 self.assertEqual(proto.data, self.DATA) 210 self.assertEqual(self.file.tell(), len(self.DATA)) 211 212 def test_sock_sendfile_with_offset_and_count(self): 213 sock, proto = self.prepare_socksendfile() 214 ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, 215 1000, 2000)) 216 sock.close() 217 self.run_loop(proto.wait_closed()) 218 219 self.assertEqual(proto.data, self.DATA[1000:3000]) 220 self.assertEqual(self.file.tell(), 3000) 221 self.assertEqual(ret, 2000) 222 223 def test_sock_sendfile_zero_size(self): 224 sock, proto = self.prepare_socksendfile() 225 with tempfile.TemporaryFile() as f: 226 ret = self.run_loop(self.loop.sock_sendfile(sock, f, 227 0, None)) 228 sock.close() 229 self.run_loop(proto.wait_closed()) 230 231 self.assertEqual(ret, 0) 232 self.assertEqual(self.file.tell(), 0) 233 234 def test_sock_sendfile_mix_with_regular_send(self): 235 buf = b"mix_regular_send" * (4 * 1024) # 64 KiB 236 sock, proto = self.prepare_socksendfile() 237 self.run_loop(self.loop.sock_sendall(sock, buf)) 238 ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) 239 self.run_loop(self.loop.sock_sendall(sock, buf)) 240 sock.close() 241 self.run_loop(proto.wait_closed()) 242 243 self.assertEqual(ret, len(self.DATA)) 244 expected = buf + self.DATA + buf 245 self.assertEqual(proto.data, expected) 246 self.assertEqual(self.file.tell(), len(self.DATA)) 247 248 249class SendfileMixin(SendfileBase): 250 251 # Note: sendfile via SSL transport is equal to sendfile fallback 252 253 def prepare_sendfile(self, *, is_ssl=False, close_after=0): 254 port = socket_helper.find_unused_port() 255 srv_proto = MySendfileProto(loop=self.loop, 256 close_after=close_after) 257 if is_ssl: 258 if not ssl: 259 self.skipTest("No ssl module") 260 srv_ctx = test_utils.simple_server_sslcontext() 261 cli_ctx = test_utils.simple_client_sslcontext() 262 else: 263 srv_ctx = None 264 cli_ctx = None 265 srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 266 srv_sock.bind((socket_helper.HOST, port)) 267 server = self.run_loop(self.loop.create_server( 268 lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) 269 self.reduce_receive_buffer_size(srv_sock) 270 271 if is_ssl: 272 server_hostname = socket_helper.HOST 273 else: 274 server_hostname = None 275 cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 276 cli_sock.connect((socket_helper.HOST, port)) 277 278 cli_proto = MySendfileProto(loop=self.loop) 279 tr, pr = self.run_loop(self.loop.create_connection( 280 lambda: cli_proto, sock=cli_sock, 281 ssl=cli_ctx, server_hostname=server_hostname)) 282 self.reduce_send_buffer_size(cli_sock, transport=tr) 283 284 def cleanup(): 285 srv_proto.transport.close() 286 cli_proto.transport.close() 287 self.run_loop(srv_proto.done) 288 self.run_loop(cli_proto.done) 289 290 server.close() 291 self.run_loop(server.wait_closed()) 292 293 self.addCleanup(cleanup) 294 return srv_proto, cli_proto 295 296 @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") 297 def test_sendfile_not_supported(self): 298 tr, pr = self.run_loop( 299 self.loop.create_datagram_endpoint( 300 asyncio.DatagramProtocol, 301 family=socket.AF_INET)) 302 try: 303 with self.assertRaisesRegex(RuntimeError, "not supported"): 304 self.run_loop( 305 self.loop.sendfile(tr, self.file)) 306 self.assertEqual(0, self.file.tell()) 307 finally: 308 # don't use self.addCleanup because it produces resource warning 309 tr.close() 310 311 def test_sendfile(self): 312 srv_proto, cli_proto = self.prepare_sendfile() 313 ret = self.run_loop( 314 self.loop.sendfile(cli_proto.transport, self.file)) 315 cli_proto.transport.close() 316 self.run_loop(srv_proto.done) 317 self.assertEqual(ret, len(self.DATA)) 318 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 319 self.assertEqual(srv_proto.data, self.DATA) 320 self.assertEqual(self.file.tell(), len(self.DATA)) 321 322 def test_sendfile_force_fallback(self): 323 srv_proto, cli_proto = self.prepare_sendfile() 324 325 def sendfile_native(transp, file, offset, count): 326 # to raise SendfileNotAvailableError 327 return base_events.BaseEventLoop._sendfile_native( 328 self.loop, transp, file, offset, count) 329 330 self.loop._sendfile_native = sendfile_native 331 332 ret = self.run_loop( 333 self.loop.sendfile(cli_proto.transport, self.file)) 334 cli_proto.transport.close() 335 self.run_loop(srv_proto.done) 336 self.assertEqual(ret, len(self.DATA)) 337 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 338 self.assertEqual(srv_proto.data, self.DATA) 339 self.assertEqual(self.file.tell(), len(self.DATA)) 340 341 def test_sendfile_force_unsupported_native(self): 342 if sys.platform == 'win32': 343 if isinstance(self.loop, asyncio.ProactorEventLoop): 344 self.skipTest("Fails on proactor event loop") 345 srv_proto, cli_proto = self.prepare_sendfile() 346 347 def sendfile_native(transp, file, offset, count): 348 # to raise SendfileNotAvailableError 349 return base_events.BaseEventLoop._sendfile_native( 350 self.loop, transp, file, offset, count) 351 352 self.loop._sendfile_native = sendfile_native 353 354 with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, 355 "not supported"): 356 self.run_loop( 357 self.loop.sendfile(cli_proto.transport, self.file, 358 fallback=False)) 359 360 cli_proto.transport.close() 361 self.run_loop(srv_proto.done) 362 self.assertEqual(srv_proto.nbytes, 0) 363 self.assertEqual(self.file.tell(), 0) 364 365 def test_sendfile_ssl(self): 366 srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) 367 ret = self.run_loop( 368 self.loop.sendfile(cli_proto.transport, self.file)) 369 cli_proto.transport.close() 370 self.run_loop(srv_proto.done) 371 self.assertEqual(ret, len(self.DATA)) 372 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 373 self.assertEqual(srv_proto.data, self.DATA) 374 self.assertEqual(self.file.tell(), len(self.DATA)) 375 376 def test_sendfile_for_closing_transp(self): 377 srv_proto, cli_proto = self.prepare_sendfile() 378 cli_proto.transport.close() 379 with self.assertRaisesRegex(RuntimeError, "is closing"): 380 self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) 381 self.run_loop(srv_proto.done) 382 self.assertEqual(srv_proto.nbytes, 0) 383 self.assertEqual(self.file.tell(), 0) 384 385 def test_sendfile_pre_and_post_data(self): 386 srv_proto, cli_proto = self.prepare_sendfile() 387 PREFIX = b'PREFIX__' * 1024 # 8 KiB 388 SUFFIX = b'--SUFFIX' * 1024 # 8 KiB 389 cli_proto.transport.write(PREFIX) 390 ret = self.run_loop( 391 self.loop.sendfile(cli_proto.transport, self.file)) 392 cli_proto.transport.write(SUFFIX) 393 cli_proto.transport.close() 394 self.run_loop(srv_proto.done) 395 self.assertEqual(ret, len(self.DATA)) 396 self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) 397 self.assertEqual(self.file.tell(), len(self.DATA)) 398 399 def test_sendfile_ssl_pre_and_post_data(self): 400 srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) 401 PREFIX = b'zxcvbnm' * 1024 402 SUFFIX = b'0987654321' * 1024 403 cli_proto.transport.write(PREFIX) 404 ret = self.run_loop( 405 self.loop.sendfile(cli_proto.transport, self.file)) 406 cli_proto.transport.write(SUFFIX) 407 cli_proto.transport.close() 408 self.run_loop(srv_proto.done) 409 self.assertEqual(ret, len(self.DATA)) 410 self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) 411 self.assertEqual(self.file.tell(), len(self.DATA)) 412 413 def test_sendfile_partial(self): 414 srv_proto, cli_proto = self.prepare_sendfile() 415 ret = self.run_loop( 416 self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) 417 cli_proto.transport.close() 418 self.run_loop(srv_proto.done) 419 self.assertEqual(ret, 100) 420 self.assertEqual(srv_proto.nbytes, 100) 421 self.assertEqual(srv_proto.data, self.DATA[1000:1100]) 422 self.assertEqual(self.file.tell(), 1100) 423 424 def test_sendfile_ssl_partial(self): 425 srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) 426 ret = self.run_loop( 427 self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) 428 cli_proto.transport.close() 429 self.run_loop(srv_proto.done) 430 self.assertEqual(ret, 100) 431 self.assertEqual(srv_proto.nbytes, 100) 432 self.assertEqual(srv_proto.data, self.DATA[1000:1100]) 433 self.assertEqual(self.file.tell(), 1100) 434 435 def test_sendfile_close_peer_after_receiving(self): 436 srv_proto, cli_proto = self.prepare_sendfile( 437 close_after=len(self.DATA)) 438 ret = self.run_loop( 439 self.loop.sendfile(cli_proto.transport, self.file)) 440 cli_proto.transport.close() 441 self.run_loop(srv_proto.done) 442 self.assertEqual(ret, len(self.DATA)) 443 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 444 self.assertEqual(srv_proto.data, self.DATA) 445 self.assertEqual(self.file.tell(), len(self.DATA)) 446 447 def test_sendfile_ssl_close_peer_after_receiving(self): 448 srv_proto, cli_proto = self.prepare_sendfile( 449 is_ssl=True, close_after=len(self.DATA)) 450 ret = self.run_loop( 451 self.loop.sendfile(cli_proto.transport, self.file)) 452 self.run_loop(srv_proto.done) 453 self.assertEqual(ret, len(self.DATA)) 454 self.assertEqual(srv_proto.nbytes, len(self.DATA)) 455 self.assertEqual(srv_proto.data, self.DATA) 456 self.assertEqual(self.file.tell(), len(self.DATA)) 457 458 # On Solaris, lowering SO_RCVBUF on a TCP connection after it has been 459 # established has no effect. Due to its age, this bug affects both Oracle 460 # Solaris as well as all other OpenSolaris forks (unless they fixed it 461 # themselves). 462 @unittest.skipIf(sys.platform.startswith('sunos'), 463 "Doesn't work on Solaris") 464 def test_sendfile_close_peer_in_the_middle_of_receiving(self): 465 srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) 466 with self.assertRaises(ConnectionError): 467 self.run_loop( 468 self.loop.sendfile(cli_proto.transport, self.file)) 469 self.run_loop(srv_proto.done) 470 471 self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), 472 srv_proto.nbytes) 473 self.assertTrue(1024 <= self.file.tell() < len(self.DATA), 474 self.file.tell()) 475 self.assertTrue(cli_proto.transport.is_closing()) 476 477 def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): 478 479 def sendfile_native(transp, file, offset, count): 480 # to raise SendfileNotAvailableError 481 return base_events.BaseEventLoop._sendfile_native( 482 self.loop, transp, file, offset, count) 483 484 self.loop._sendfile_native = sendfile_native 485 486 srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) 487 with self.assertRaises(ConnectionError): 488 try: 489 self.run_loop( 490 self.loop.sendfile(cli_proto.transport, self.file)) 491 except OSError as e: 492 # macOS may raise OSError of EPROTOTYPE when writing to a 493 # socket that is in the process of closing down. 494 if e.errno == errno.EPROTOTYPE and sys.platform == "darwin": 495 raise ConnectionError 496 else: 497 raise 498 499 self.run_loop(srv_proto.done) 500 501 self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), 502 srv_proto.nbytes) 503 self.assertTrue(1024 <= self.file.tell() < len(self.DATA), 504 self.file.tell()) 505 506 @unittest.skipIf(not hasattr(os, 'sendfile'), 507 "Don't have native sendfile support") 508 def test_sendfile_prevents_bare_write(self): 509 srv_proto, cli_proto = self.prepare_sendfile() 510 fut = self.loop.create_future() 511 512 async def coro(): 513 fut.set_result(None) 514 return await self.loop.sendfile(cli_proto.transport, self.file) 515 516 t = self.loop.create_task(coro()) 517 self.run_loop(fut) 518 with self.assertRaisesRegex(RuntimeError, 519 "sendfile is in progress"): 520 cli_proto.transport.write(b'data') 521 ret = self.run_loop(t) 522 self.assertEqual(ret, len(self.DATA)) 523 524 def test_sendfile_no_fallback_for_fallback_transport(self): 525 transport = mock.Mock() 526 transport.is_closing.side_effect = lambda: False 527 transport._sendfile_compatible = constants._SendfileMode.FALLBACK 528 with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'): 529 self.loop.run_until_complete( 530 self.loop.sendfile(transport, None, fallback=False)) 531 532 533class SendfileTestsBase(SendfileMixin, SockSendfileMixin): 534 pass 535 536 537if sys.platform == 'win32': 538 539 class SelectEventLoopTests(SendfileTestsBase, 540 test_utils.TestCase): 541 542 def create_event_loop(self): 543 return asyncio.SelectorEventLoop() 544 545 class ProactorEventLoopTests(SendfileTestsBase, 546 test_utils.TestCase): 547 548 def create_event_loop(self): 549 return asyncio.ProactorEventLoop() 550 551else: 552 import selectors 553 554 if hasattr(selectors, 'KqueueSelector'): 555 class KqueueEventLoopTests(SendfileTestsBase, 556 test_utils.TestCase): 557 558 def create_event_loop(self): 559 return asyncio.SelectorEventLoop( 560 selectors.KqueueSelector()) 561 562 if hasattr(selectors, 'EpollSelector'): 563 class EPollEventLoopTests(SendfileTestsBase, 564 test_utils.TestCase): 565 566 def create_event_loop(self): 567 return asyncio.SelectorEventLoop(selectors.EpollSelector()) 568 569 if hasattr(selectors, 'PollSelector'): 570 class PollEventLoopTests(SendfileTestsBase, 571 test_utils.TestCase): 572 573 def create_event_loop(self): 574 return asyncio.SelectorEventLoop(selectors.PollSelector()) 575 576 # Should always exist. 577 class SelectEventLoopTests(SendfileTestsBase, 578 test_utils.TestCase): 579 580 def create_event_loop(self): 581 return asyncio.SelectorEventLoop(selectors.SelectSelector()) 582 583 584if __name__ == '__main__': 585 unittest.main() 586