1""" 2Test suite for SocketServer.py. 3""" 4 5import contextlib 6import imp 7import os 8import select 9import signal 10import socket 11import select 12import errno 13import tempfile 14import unittest 15import SocketServer 16 17import test.test_support 18from test.test_support import reap_children, reap_threads, verbose 19try: 20 import threading 21except ImportError: 22 threading = None 23 24test.test_support.requires("network") 25 26TEST_STR = "hello world\n" 27HOST = test.test_support.HOST 28 29HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX") 30requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS, 31 'requires Unix sockets') 32HAVE_FORKING = hasattr(os, "fork") and os.name != "os2" 33requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking') 34 35def signal_alarm(n): 36 """Call signal.alarm when it exists (i.e. not on Windows).""" 37 if hasattr(signal, 'alarm'): 38 signal.alarm(n) 39 40# Remember real select() to avoid interferences with mocking 41_real_select = select.select 42 43def receive(sock, n, timeout=20): 44 r, w, x = _real_select([sock], [], [], timeout) 45 if sock in r: 46 return sock.recv(n) 47 else: 48 raise RuntimeError, "timed out on %r" % (sock,) 49 50if HAVE_UNIX_SOCKETS: 51 class ForkingUnixStreamServer(SocketServer.ForkingMixIn, 52 SocketServer.UnixStreamServer): 53 pass 54 55 class ForkingUnixDatagramServer(SocketServer.ForkingMixIn, 56 SocketServer.UnixDatagramServer): 57 pass 58 59 60@contextlib.contextmanager 61def simple_subprocess(testcase): 62 pid = os.fork() 63 if pid == 0: 64 # Don't raise an exception; it would be caught by the test harness. 65 os._exit(72) 66 yield None 67 pid2, status = os.waitpid(pid, 0) 68 testcase.assertEqual(pid2, pid) 69 testcase.assertEqual(72 << 8, status) 70 71 72def close_server(server): 73 server.server_close() 74 75 if hasattr(server, 'active_children'): 76 # ForkingMixIn: Manually reap all child processes, since server_close() 77 # calls waitpid() in non-blocking mode using the WNOHANG flag. 78 for pid in server.active_children.copy(): 79 try: 80 os.waitpid(pid, 0) 81 except ChildProcessError: 82 pass 83 server.active_children.clear() 84 85 86@unittest.skipUnless(threading, 'Threading required for this test.') 87class SocketServerTest(unittest.TestCase): 88 """Test all socket servers.""" 89 90 def setUp(self): 91 self.addCleanup(signal_alarm, 0) 92 signal_alarm(60) # Kill deadlocks after 60 seconds. 93 self.port_seed = 0 94 self.test_files = [] 95 96 def tearDown(self): 97 self.doCleanups() 98 reap_children() 99 100 for fn in self.test_files: 101 try: 102 os.remove(fn) 103 except os.error: 104 pass 105 self.test_files[:] = [] 106 107 def pickaddr(self, proto): 108 if proto == socket.AF_INET: 109 return (HOST, 0) 110 else: 111 # XXX: We need a way to tell AF_UNIX to pick its own name 112 # like AF_INET provides port==0. 113 dir = None 114 if os.name == 'os2': 115 dir = '\socket' 116 fn = tempfile.mktemp(prefix='unix_socket.', dir=dir) 117 if os.name == 'os2': 118 # AF_UNIX socket names on OS/2 require a specific prefix 119 # which can't include a drive letter and must also use 120 # backslashes as directory separators 121 if fn[1] == ':': 122 fn = fn[2:] 123 if fn[0] in (os.sep, os.altsep): 124 fn = fn[1:] 125 if os.sep == '/': 126 fn = fn.replace(os.sep, os.altsep) 127 else: 128 fn = fn.replace(os.altsep, os.sep) 129 self.test_files.append(fn) 130 return fn 131 132 def make_server(self, addr, svrcls, hdlrbase): 133 class MyServer(svrcls): 134 def handle_error(self, request, client_address): 135 self.close_request(request) 136 close_server(self) 137 raise 138 139 class MyHandler(hdlrbase): 140 def handle(self): 141 line = self.rfile.readline() 142 self.wfile.write(line) 143 144 if verbose: print "creating server" 145 server = MyServer(addr, MyHandler) 146 self.assertEqual(server.server_address, server.socket.getsockname()) 147 return server 148 149 @reap_threads 150 def run_server(self, svrcls, hdlrbase, testfunc): 151 server = self.make_server(self.pickaddr(svrcls.address_family), 152 svrcls, hdlrbase) 153 # We had the OS pick a port, so pull the real address out of 154 # the server. 155 addr = server.server_address 156 if verbose: 157 print "server created" 158 print "ADDR =", addr 159 print "CLASS =", svrcls 160 t = threading.Thread( 161 name='%s serving' % svrcls, 162 target=server.serve_forever, 163 # Short poll interval to make the test finish quickly. 164 # Time between requests is short enough that we won't wake 165 # up spuriously too many times. 166 kwargs={'poll_interval':0.01}) 167 t.daemon = True # In case this function raises. 168 t.start() 169 if verbose: print "server running" 170 for i in range(3): 171 if verbose: print "test client", i 172 testfunc(svrcls.address_family, addr) 173 if verbose: print "waiting for server" 174 server.shutdown() 175 t.join() 176 close_server(server) 177 self.assertRaises(socket.error, server.socket.fileno) 178 if verbose: print "done" 179 180 def stream_examine(self, proto, addr): 181 s = socket.socket(proto, socket.SOCK_STREAM) 182 s.connect(addr) 183 s.sendall(TEST_STR) 184 buf = data = receive(s, 100) 185 while data and '\n' not in buf: 186 data = receive(s, 100) 187 buf += data 188 self.assertEqual(buf, TEST_STR) 189 s.close() 190 191 def dgram_examine(self, proto, addr): 192 s = socket.socket(proto, socket.SOCK_DGRAM) 193 if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX: 194 s.bind(self.pickaddr(proto)) 195 s.sendto(TEST_STR, addr) 196 buf = data = receive(s, 100) 197 while data and '\n' not in buf: 198 data = receive(s, 100) 199 buf += data 200 self.assertEqual(buf, TEST_STR) 201 s.close() 202 203 def test_TCPServer(self): 204 self.run_server(SocketServer.TCPServer, 205 SocketServer.StreamRequestHandler, 206 self.stream_examine) 207 208 def test_ThreadingTCPServer(self): 209 self.run_server(SocketServer.ThreadingTCPServer, 210 SocketServer.StreamRequestHandler, 211 self.stream_examine) 212 213 @requires_forking 214 def test_ForkingTCPServer(self): 215 with simple_subprocess(self): 216 self.run_server(SocketServer.ForkingTCPServer, 217 SocketServer.StreamRequestHandler, 218 self.stream_examine) 219 220 @requires_unix_sockets 221 def test_UnixStreamServer(self): 222 self.run_server(SocketServer.UnixStreamServer, 223 SocketServer.StreamRequestHandler, 224 self.stream_examine) 225 226 @requires_unix_sockets 227 def test_ThreadingUnixStreamServer(self): 228 self.run_server(SocketServer.ThreadingUnixStreamServer, 229 SocketServer.StreamRequestHandler, 230 self.stream_examine) 231 232 @requires_unix_sockets 233 @requires_forking 234 def test_ForkingUnixStreamServer(self): 235 with simple_subprocess(self): 236 self.run_server(ForkingUnixStreamServer, 237 SocketServer.StreamRequestHandler, 238 self.stream_examine) 239 240 def test_UDPServer(self): 241 self.run_server(SocketServer.UDPServer, 242 SocketServer.DatagramRequestHandler, 243 self.dgram_examine) 244 245 def test_ThreadingUDPServer(self): 246 self.run_server(SocketServer.ThreadingUDPServer, 247 SocketServer.DatagramRequestHandler, 248 self.dgram_examine) 249 250 @requires_forking 251 def test_ForkingUDPServer(self): 252 with simple_subprocess(self): 253 self.run_server(SocketServer.ForkingUDPServer, 254 SocketServer.DatagramRequestHandler, 255 self.dgram_examine) 256 257 @contextlib.contextmanager 258 def mocked_select_module(self): 259 """Mocks the select.select() call to raise EINTR for first call""" 260 old_select = select.select 261 262 class MockSelect: 263 def __init__(self): 264 self.called = 0 265 266 def __call__(self, *args): 267 self.called += 1 268 if self.called == 1: 269 # raise the exception on first call 270 raise select.error(errno.EINTR, os.strerror(errno.EINTR)) 271 else: 272 # Return real select value for consecutive calls 273 return old_select(*args) 274 275 select.select = MockSelect() 276 try: 277 yield select.select 278 finally: 279 select.select = old_select 280 281 def test_InterruptServerSelectCall(self): 282 with self.mocked_select_module() as mock_select: 283 pid = self.run_server(SocketServer.TCPServer, 284 SocketServer.StreamRequestHandler, 285 self.stream_examine) 286 # Make sure select was called again: 287 self.assertGreater(mock_select.called, 1) 288 289 @requires_unix_sockets 290 def test_UnixDatagramServer(self): 291 self.run_server(SocketServer.UnixDatagramServer, 292 SocketServer.DatagramRequestHandler, 293 self.dgram_examine) 294 295 @requires_unix_sockets 296 def test_ThreadingUnixDatagramServer(self): 297 self.run_server(SocketServer.ThreadingUnixDatagramServer, 298 SocketServer.DatagramRequestHandler, 299 self.dgram_examine) 300 301 @requires_unix_sockets 302 @requires_forking 303 def test_ForkingUnixDatagramServer(self): 304 self.run_server(ForkingUnixDatagramServer, 305 SocketServer.DatagramRequestHandler, 306 self.dgram_examine) 307 308 @reap_threads 309 def test_shutdown(self): 310 # Issue #2302: shutdown() should always succeed in making an 311 # other thread leave serve_forever(). 312 class MyServer(SocketServer.TCPServer): 313 pass 314 315 class MyHandler(SocketServer.StreamRequestHandler): 316 pass 317 318 threads = [] 319 for i in range(20): 320 s = MyServer((HOST, 0), MyHandler) 321 t = threading.Thread( 322 name='MyServer serving', 323 target=s.serve_forever, 324 kwargs={'poll_interval':0.01}) 325 t.daemon = True # In case this function raises. 326 threads.append((t, s)) 327 for t, s in threads: 328 t.start() 329 s.shutdown() 330 for t, s in threads: 331 t.join() 332 close_server(s) 333 334 def test_tcpserver_bind_leak(self): 335 # Issue #22435: the server socket wouldn't be closed if bind()/listen() 336 # failed. 337 # Create many servers for which bind() will fail, to see if this result 338 # in FD exhaustion. 339 for i in range(1024): 340 with self.assertRaises(OverflowError): 341 SocketServer.TCPServer((HOST, -1), 342 SocketServer.StreamRequestHandler) 343 344 345class MiscTestCase(unittest.TestCase): 346 347 def test_shutdown_request_called_if_verify_request_false(self): 348 # Issue #26309: BaseServer should call shutdown_request even if 349 # verify_request is False 350 351 class MyServer(SocketServer.TCPServer): 352 def verify_request(self, request, client_address): 353 return False 354 355 shutdown_called = 0 356 def shutdown_request(self, request): 357 self.shutdown_called += 1 358 SocketServer.TCPServer.shutdown_request(self, request) 359 360 server = MyServer((HOST, 0), SocketServer.StreamRequestHandler) 361 s = socket.socket(server.address_family, socket.SOCK_STREAM) 362 s.connect(server.server_address) 363 s.close() 364 server.handle_request() 365 self.assertEqual(server.shutdown_called, 1) 366 close_server(server) 367 368 369def test_main(): 370 if imp.lock_held(): 371 # If the import lock is held, the threads will hang 372 raise unittest.SkipTest("can't run when import lock is held") 373 374 test.test_support.run_unittest(SocketServerTest) 375 376if __name__ == "__main__": 377 test_main() 378