1""" 2Test suite for socketserver. 3""" 4 5import contextlib 6import io 7import os 8import select 9import signal 10import socket 11import tempfile 12import threading 13import unittest 14import socketserver 15 16import test.support 17from test.support import reap_children, verbose 18from test.support import os_helper 19from test.support import socket_helper 20from test.support import threading_helper 21 22 23test.support.requires("network") 24test.support.requires_working_socket(module=True) 25 26 27TEST_STR = b"hello world\n" 28HOST = socket_helper.HOST 29 30HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") 31requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS, 32 'requires Unix sockets') 33HAVE_FORKING = test.support.has_fork_support 34requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking') 35 36def signal_alarm(n): 37 """Call signal.alarm when it exists (i.e. not on Windows).""" 38 if hasattr(signal, 'alarm'): 39 signal.alarm(n) 40 41# Remember real select() to avoid interferences with mocking 42_real_select = select.select 43 44def receive(sock, n, timeout=test.support.SHORT_TIMEOUT): 45 r, w, x = _real_select([sock], [], [], timeout) 46 if sock in r: 47 return sock.recv(n) 48 else: 49 raise RuntimeError("timed out on %r" % (sock,)) 50 51if HAVE_UNIX_SOCKETS and HAVE_FORKING: 52 class ForkingUnixStreamServer(socketserver.ForkingMixIn, 53 socketserver.UnixStreamServer): 54 pass 55 56 class ForkingUnixDatagramServer(socketserver.ForkingMixIn, 57 socketserver.UnixDatagramServer): 58 pass 59 60 61@contextlib.contextmanager 62def simple_subprocess(testcase): 63 """Tests that a custom child process is not waited on (Issue 1540386)""" 64 pid = os.fork() 65 if pid == 0: 66 # Don't raise an exception; it would be caught by the test harness. 67 os._exit(72) 68 try: 69 yield None 70 except: 71 raise 72 finally: 73 test.support.wait_process(pid, exitcode=72) 74 75 76class SocketServerTest(unittest.TestCase): 77 """Test all socket servers.""" 78 79 def setUp(self): 80 signal_alarm(60) # Kill deadlocks after 60 seconds. 81 self.port_seed = 0 82 self.test_files = [] 83 84 def tearDown(self): 85 signal_alarm(0) # Didn't deadlock. 86 reap_children() 87 88 for fn in self.test_files: 89 try: 90 os.remove(fn) 91 except OSError: 92 pass 93 self.test_files[:] = [] 94 95 def pickaddr(self, proto): 96 if proto == socket.AF_INET: 97 return (HOST, 0) 98 else: 99 # XXX: We need a way to tell AF_UNIX to pick its own name 100 # like AF_INET provides port==0. 101 dir = None 102 fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) 103 self.test_files.append(fn) 104 return fn 105 106 def make_server(self, addr, svrcls, hdlrbase): 107 class MyServer(svrcls): 108 def handle_error(self, request, client_address): 109 self.close_request(request) 110 raise 111 112 class MyHandler(hdlrbase): 113 def handle(self): 114 line = self.rfile.readline() 115 self.wfile.write(line) 116 117 if verbose: print("creating server") 118 try: 119 server = MyServer(addr, MyHandler) 120 except PermissionError as e: 121 # Issue 29184: cannot bind() a Unix socket on Android. 122 self.skipTest('Cannot create server (%s, %s): %s' % 123 (svrcls, addr, e)) 124 self.assertEqual(server.server_address, server.socket.getsockname()) 125 return server 126 127 @threading_helper.reap_threads 128 def run_server(self, svrcls, hdlrbase, testfunc): 129 server = self.make_server(self.pickaddr(svrcls.address_family), 130 svrcls, hdlrbase) 131 # We had the OS pick a port, so pull the real address out of 132 # the server. 133 addr = server.server_address 134 if verbose: 135 print("ADDR =", addr) 136 print("CLASS =", svrcls) 137 138 t = threading.Thread( 139 name='%s serving' % svrcls, 140 target=server.serve_forever, 141 # Short poll interval to make the test finish quickly. 142 # Time between requests is short enough that we won't wake 143 # up spuriously too many times. 144 kwargs={'poll_interval':0.01}) 145 t.daemon = True # In case this function raises. 146 t.start() 147 if verbose: print("server running") 148 for i in range(3): 149 if verbose: print("test client", i) 150 testfunc(svrcls.address_family, addr) 151 if verbose: print("waiting for server") 152 server.shutdown() 153 t.join() 154 server.server_close() 155 self.assertEqual(-1, server.socket.fileno()) 156 if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn): 157 # bpo-31151: Check that ForkingMixIn.server_close() waits until 158 # all children completed 159 self.assertFalse(server.active_children) 160 if verbose: print("done") 161 162 def stream_examine(self, proto, addr): 163 with socket.socket(proto, socket.SOCK_STREAM) as s: 164 s.connect(addr) 165 s.sendall(TEST_STR) 166 buf = data = receive(s, 100) 167 while data and b'\n' not in buf: 168 data = receive(s, 100) 169 buf += data 170 self.assertEqual(buf, TEST_STR) 171 172 def dgram_examine(self, proto, addr): 173 with socket.socket(proto, socket.SOCK_DGRAM) as s: 174 if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX: 175 s.bind(self.pickaddr(proto)) 176 s.sendto(TEST_STR, addr) 177 buf = data = receive(s, 100) 178 while data and b'\n' not in buf: 179 data = receive(s, 100) 180 buf += data 181 self.assertEqual(buf, TEST_STR) 182 183 def test_TCPServer(self): 184 self.run_server(socketserver.TCPServer, 185 socketserver.StreamRequestHandler, 186 self.stream_examine) 187 188 def test_ThreadingTCPServer(self): 189 self.run_server(socketserver.ThreadingTCPServer, 190 socketserver.StreamRequestHandler, 191 self.stream_examine) 192 193 @requires_forking 194 def test_ForkingTCPServer(self): 195 with simple_subprocess(self): 196 self.run_server(socketserver.ForkingTCPServer, 197 socketserver.StreamRequestHandler, 198 self.stream_examine) 199 200 @requires_unix_sockets 201 def test_UnixStreamServer(self): 202 self.run_server(socketserver.UnixStreamServer, 203 socketserver.StreamRequestHandler, 204 self.stream_examine) 205 206 @requires_unix_sockets 207 def test_ThreadingUnixStreamServer(self): 208 self.run_server(socketserver.ThreadingUnixStreamServer, 209 socketserver.StreamRequestHandler, 210 self.stream_examine) 211 212 @requires_unix_sockets 213 @requires_forking 214 def test_ForkingUnixStreamServer(self): 215 with simple_subprocess(self): 216 self.run_server(ForkingUnixStreamServer, 217 socketserver.StreamRequestHandler, 218 self.stream_examine) 219 220 def test_UDPServer(self): 221 self.run_server(socketserver.UDPServer, 222 socketserver.DatagramRequestHandler, 223 self.dgram_examine) 224 225 def test_ThreadingUDPServer(self): 226 self.run_server(socketserver.ThreadingUDPServer, 227 socketserver.DatagramRequestHandler, 228 self.dgram_examine) 229 230 @requires_forking 231 def test_ForkingUDPServer(self): 232 with simple_subprocess(self): 233 self.run_server(socketserver.ForkingUDPServer, 234 socketserver.DatagramRequestHandler, 235 self.dgram_examine) 236 237 @requires_unix_sockets 238 def test_UnixDatagramServer(self): 239 self.run_server(socketserver.UnixDatagramServer, 240 socketserver.DatagramRequestHandler, 241 self.dgram_examine) 242 243 @requires_unix_sockets 244 def test_ThreadingUnixDatagramServer(self): 245 self.run_server(socketserver.ThreadingUnixDatagramServer, 246 socketserver.DatagramRequestHandler, 247 self.dgram_examine) 248 249 @requires_unix_sockets 250 @requires_forking 251 def test_ForkingUnixDatagramServer(self): 252 self.run_server(ForkingUnixDatagramServer, 253 socketserver.DatagramRequestHandler, 254 self.dgram_examine) 255 256 @threading_helper.reap_threads 257 def test_shutdown(self): 258 # Issue #2302: shutdown() should always succeed in making an 259 # other thread leave serve_forever(). 260 class MyServer(socketserver.TCPServer): 261 pass 262 263 class MyHandler(socketserver.StreamRequestHandler): 264 pass 265 266 threads = [] 267 for i in range(20): 268 s = MyServer((HOST, 0), MyHandler) 269 t = threading.Thread( 270 name='MyServer serving', 271 target=s.serve_forever, 272 kwargs={'poll_interval':0.01}) 273 t.daemon = True # In case this function raises. 274 threads.append((t, s)) 275 for t, s in threads: 276 t.start() 277 s.shutdown() 278 for t, s in threads: 279 t.join() 280 s.server_close() 281 282 def test_close_immediately(self): 283 class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): 284 pass 285 286 server = MyServer((HOST, 0), lambda: None) 287 server.server_close() 288 289 def test_tcpserver_bind_leak(self): 290 # Issue #22435: the server socket wouldn't be closed if bind()/listen() 291 # failed. 292 # Create many servers for which bind() will fail, to see if this result 293 # in FD exhaustion. 294 for i in range(1024): 295 with self.assertRaises(OverflowError): 296 socketserver.TCPServer((HOST, -1), 297 socketserver.StreamRequestHandler) 298 299 def test_context_manager(self): 300 with socketserver.TCPServer((HOST, 0), 301 socketserver.StreamRequestHandler) as server: 302 pass 303 self.assertEqual(-1, server.socket.fileno()) 304 305 306class ErrorHandlerTest(unittest.TestCase): 307 """Test that the servers pass normal exceptions from the handler to 308 handle_error(), and that exiting exceptions like SystemExit and 309 KeyboardInterrupt are not passed.""" 310 311 def tearDown(self): 312 os_helper.unlink(os_helper.TESTFN) 313 314 def test_sync_handled(self): 315 BaseErrorTestServer(ValueError) 316 self.check_result(handled=True) 317 318 def test_sync_not_handled(self): 319 with self.assertRaises(SystemExit): 320 BaseErrorTestServer(SystemExit) 321 self.check_result(handled=False) 322 323 def test_threading_handled(self): 324 ThreadingErrorTestServer(ValueError) 325 self.check_result(handled=True) 326 327 def test_threading_not_handled(self): 328 with threading_helper.catch_threading_exception() as cm: 329 ThreadingErrorTestServer(SystemExit) 330 self.check_result(handled=False) 331 332 self.assertIs(cm.exc_type, SystemExit) 333 334 @requires_forking 335 def test_forking_handled(self): 336 ForkingErrorTestServer(ValueError) 337 self.check_result(handled=True) 338 339 @requires_forking 340 def test_forking_not_handled(self): 341 ForkingErrorTestServer(SystemExit) 342 self.check_result(handled=False) 343 344 def check_result(self, handled): 345 with open(os_helper.TESTFN) as log: 346 expected = 'Handler called\n' + 'Error handled\n' * handled 347 self.assertEqual(log.read(), expected) 348 349 350class BaseErrorTestServer(socketserver.TCPServer): 351 def __init__(self, exception): 352 self.exception = exception 353 super().__init__((HOST, 0), BadHandler) 354 with socket.create_connection(self.server_address): 355 pass 356 try: 357 self.handle_request() 358 finally: 359 self.server_close() 360 self.wait_done() 361 362 def handle_error(self, request, client_address): 363 with open(os_helper.TESTFN, 'a') as log: 364 log.write('Error handled\n') 365 366 def wait_done(self): 367 pass 368 369 370class BadHandler(socketserver.BaseRequestHandler): 371 def handle(self): 372 with open(os_helper.TESTFN, 'a') as log: 373 log.write('Handler called\n') 374 raise self.server.exception('Test error') 375 376 377class ThreadingErrorTestServer(socketserver.ThreadingMixIn, 378 BaseErrorTestServer): 379 def __init__(self, *pos, **kw): 380 self.done = threading.Event() 381 super().__init__(*pos, **kw) 382 383 def shutdown_request(self, *pos, **kw): 384 super().shutdown_request(*pos, **kw) 385 self.done.set() 386 387 def wait_done(self): 388 self.done.wait() 389 390 391if HAVE_FORKING: 392 class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer): 393 pass 394 395 396class SocketWriterTest(unittest.TestCase): 397 def test_basics(self): 398 class Handler(socketserver.StreamRequestHandler): 399 def handle(self): 400 self.server.wfile = self.wfile 401 self.server.wfile_fileno = self.wfile.fileno() 402 self.server.request_fileno = self.request.fileno() 403 404 server = socketserver.TCPServer((HOST, 0), Handler) 405 self.addCleanup(server.server_close) 406 s = socket.socket( 407 server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) 408 with s: 409 s.connect(server.server_address) 410 server.handle_request() 411 self.assertIsInstance(server.wfile, io.BufferedIOBase) 412 self.assertEqual(server.wfile_fileno, server.request_fileno) 413 414 def test_write(self): 415 # Test that wfile.write() sends data immediately, and that it does 416 # not truncate sends when interrupted by a Unix signal 417 pthread_kill = test.support.get_attribute(signal, 'pthread_kill') 418 419 class Handler(socketserver.StreamRequestHandler): 420 def handle(self): 421 self.server.sent1 = self.wfile.write(b'write data\n') 422 # Should be sent immediately, without requiring flush() 423 self.server.received = self.rfile.readline() 424 big_chunk = b'\0' * test.support.SOCK_MAX_SIZE 425 self.server.sent2 = self.wfile.write(big_chunk) 426 427 server = socketserver.TCPServer((HOST, 0), Handler) 428 self.addCleanup(server.server_close) 429 interrupted = threading.Event() 430 431 def signal_handler(signum, frame): 432 interrupted.set() 433 434 original = signal.signal(signal.SIGUSR1, signal_handler) 435 self.addCleanup(signal.signal, signal.SIGUSR1, original) 436 response1 = None 437 received2 = None 438 main_thread = threading.get_ident() 439 440 def run_client(): 441 s = socket.socket(server.address_family, socket.SOCK_STREAM, 442 socket.IPPROTO_TCP) 443 with s, s.makefile('rb') as reader: 444 s.connect(server.server_address) 445 nonlocal response1 446 response1 = reader.readline() 447 s.sendall(b'client response\n') 448 449 reader.read(100) 450 # The main thread should now be blocking in a send() syscall. 451 # But in theory, it could get interrupted by other signals, 452 # and then retried. So keep sending the signal in a loop, in 453 # case an earlier signal happens to be delivered at an 454 # inconvenient moment. 455 while True: 456 pthread_kill(main_thread, signal.SIGUSR1) 457 if interrupted.wait(timeout=float(1)): 458 break 459 nonlocal received2 460 received2 = len(reader.read()) 461 462 background = threading.Thread(target=run_client) 463 background.start() 464 server.handle_request() 465 background.join() 466 self.assertEqual(server.sent1, len(response1)) 467 self.assertEqual(response1, b'write data\n') 468 self.assertEqual(server.received, b'client response\n') 469 self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE) 470 self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100) 471 472 473class MiscTestCase(unittest.TestCase): 474 475 def test_all(self): 476 # objects defined in the module should be in __all__ 477 expected = [] 478 for name in dir(socketserver): 479 if not name.startswith('_'): 480 mod_object = getattr(socketserver, name) 481 if getattr(mod_object, '__module__', None) == 'socketserver': 482 expected.append(name) 483 self.assertCountEqual(socketserver.__all__, expected) 484 485 def test_shutdown_request_called_if_verify_request_false(self): 486 # Issue #26309: BaseServer should call shutdown_request even if 487 # verify_request is False 488 489 class MyServer(socketserver.TCPServer): 490 def verify_request(self, request, client_address): 491 return False 492 493 shutdown_called = 0 494 def shutdown_request(self, request): 495 self.shutdown_called += 1 496 socketserver.TCPServer.shutdown_request(self, request) 497 498 server = MyServer((HOST, 0), socketserver.StreamRequestHandler) 499 s = socket.socket(server.address_family, socket.SOCK_STREAM) 500 s.connect(server.server_address) 501 s.close() 502 server.handle_request() 503 self.assertEqual(server.shutdown_called, 1) 504 server.server_close() 505 506 def test_threads_reaped(self): 507 """ 508 In #37193, users reported a memory leak 509 due to the saving of every request thread. Ensure that 510 not all threads are kept forever. 511 """ 512 class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer): 513 pass 514 515 server = MyServer((HOST, 0), socketserver.StreamRequestHandler) 516 for n in range(10): 517 with socket.create_connection(server.server_address): 518 server.handle_request() 519 self.assertLess(len(server._threads), 10) 520 server.server_close() 521 522 523if __name__ == "__main__": 524 unittest.main() 525