1"""
2Test suite for SocketServer.py.
3"""
4
5import contextlib
6import imp
7import os
8import select
9import signal
10import socket
11import select
12import errno
13import tempfile
14import unittest
15import SocketServer
16
17import test.test_support
18from test.test_support import reap_children, reap_threads, verbose
19try:
20    import threading
21except ImportError:
22    threading = None
23
# Declare that this module needs the "network" test resource.
# NOTE(review): presumably this aborts/skips the module when the resource
# is not enabled -- confirm against test.test_support.requires().
test.test_support.requires("network")

# Line sent by every test client; the handlers echo it back unchanged.
TEST_STR = "hello world\n"
# Host name/address the INET test servers bind to.
HOST = test.test_support.HOST

# Feature probes and the matching skip decorators used by the test cases.
HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
# Forking tests are disabled on OS/2 even though os.fork may exist there.
HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
34
def signal_alarm(n):
    """Schedule (or, with n == 0, cancel) a SIGALRM in *n* seconds.

    Does nothing on platforms whose signal module has no alarm()
    function (i.e. Windows).
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is not None:
        alarm(n)
39
# Remember real select() to avoid interferences with mocking
_real_select = select.select

def receive(sock, n, timeout=20):
    """Read up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    The wait uses the select() captured at import time, so it is not
    affected by tests that temporarily replace select.select.

    Returns whatever sock.recv(n) returns.
    Raises RuntimeError if the socket does not become readable in time.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        # Call form of raise: valid on both Python 2 and Python 3, unlike
        # the old "raise Class, value" statement.
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
49
if HAVE_UNIX_SOCKETS:
    # SocketServer ships no ready-made forking Unix-socket servers, so the
    # tests build them from the mix-in and the plain Unix server classes.
    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""
58
59
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Context manager that keeps an unrelated child process around.

    A child is forked that exits immediately with status 72.  After the
    ``with`` body completes, the child is waited for and *testcase* asserts
    that it was still waitable and exited with that status -- i.e. that
    nothing inside the body reaped it.

    If the body raises, the child is still reaped (best effort) so that a
    failing test does not leave a zombie behind; the body's exception is
    propagated unchanged and the assertions are skipped.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    body_ok = False
    try:
        yield None
        body_ok = True
    finally:
        if not body_ok:
            # The body failed: reap quietly and let its exception win.
            try:
                os.waitpid(pid, 0)
            except OSError:
                pass
    pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    # Exit code lives in the high byte of the wait status.
    testcase.assertEqual(72 << 8, status)
70
71
def close_server(server):
    """Close *server* and synchronously reap any children it forked.

    server.server_close() is always called.  If the server exposes a
    non-empty ``active_children`` collection (ForkingMixIn), every pid in
    it is waited for in blocking mode and the collection is emptied.

    Raises OSError if waitpid() fails for any reason other than the child
    already being gone (ECHILD).
    """
    server.server_close()

    # Servers without ForkingMixIn have no such attribute; tolerate None as
    # well so a server that never forked a child is handled gracefully.
    children = getattr(server, 'active_children', None)
    if not children:
        return

    # ForkingMixIn: Manually reap all child processes, since server_close()
    # calls waitpid() in non-blocking mode using the WNOHANG flag.
    for pid in children.copy():
        try:
            os.waitpid(pid, 0)
        except OSError as exc:
            # ChildProcessError does not exist on Python 2; check errno
            # instead and only ignore "no such child".
            if exc.errno != errno.ECHILD:
                raise
    children.clear()
84
85
86@unittest.skipUnless(threading, 'Threading required for this test.')
87class SocketServerTest(unittest.TestCase):
88    """Test all socket servers."""
89
90    def setUp(self):
91        self.addCleanup(signal_alarm, 0)
92        signal_alarm(60)  # Kill deadlocks after 60 seconds.
93        self.port_seed = 0
94        self.test_files = []
95
96    def tearDown(self):
97        self.doCleanups()
98        reap_children()
99
100        for fn in self.test_files:
101            try:
102                os.remove(fn)
103            except os.error:
104                pass
105        self.test_files[:] = []
106
107    def pickaddr(self, proto):
108        if proto == socket.AF_INET:
109            return (HOST, 0)
110        else:
111            # XXX: We need a way to tell AF_UNIX to pick its own name
112            # like AF_INET provides port==0.
113            dir = None
114            if os.name == 'os2':
115                dir = '\socket'
116            fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
117            if os.name == 'os2':
118                # AF_UNIX socket names on OS/2 require a specific prefix
119                # which can't include a drive letter and must also use
120                # backslashes as directory separators
121                if fn[1] == ':':
122                    fn = fn[2:]
123                if fn[0] in (os.sep, os.altsep):
124                    fn = fn[1:]
125                if os.sep == '/':
126                    fn = fn.replace(os.sep, os.altsep)
127                else:
128                    fn = fn.replace(os.altsep, os.sep)
129            self.test_files.append(fn)
130            return fn
131
132    def make_server(self, addr, svrcls, hdlrbase):
133        class MyServer(svrcls):
134            def handle_error(self, request, client_address):
135                self.close_request(request)
136                close_server(self)
137                raise
138
139        class MyHandler(hdlrbase):
140            def handle(self):
141                line = self.rfile.readline()
142                self.wfile.write(line)
143
144        if verbose: print "creating server"
145        server = MyServer(addr, MyHandler)
146        self.assertEqual(server.server_address, server.socket.getsockname())
147        return server
148
149    @reap_threads
150    def run_server(self, svrcls, hdlrbase, testfunc):
151        server = self.make_server(self.pickaddr(svrcls.address_family),
152                                  svrcls, hdlrbase)
153        # We had the OS pick a port, so pull the real address out of
154        # the server.
155        addr = server.server_address
156        if verbose:
157            print "server created"
158            print "ADDR =", addr
159            print "CLASS =", svrcls
160        t = threading.Thread(
161            name='%s serving' % svrcls,
162            target=server.serve_forever,
163            # Short poll interval to make the test finish quickly.
164            # Time between requests is short enough that we won't wake
165            # up spuriously too many times.
166            kwargs={'poll_interval':0.01})
167        t.daemon = True  # In case this function raises.
168        t.start()
169        if verbose: print "server running"
170        for i in range(3):
171            if verbose: print "test client", i
172            testfunc(svrcls.address_family, addr)
173        if verbose: print "waiting for server"
174        server.shutdown()
175        t.join()
176        close_server(server)
177        self.assertRaises(socket.error, server.socket.fileno)
178        if verbose: print "done"
179
180    def stream_examine(self, proto, addr):
181        s = socket.socket(proto, socket.SOCK_STREAM)
182        s.connect(addr)
183        s.sendall(TEST_STR)
184        buf = data = receive(s, 100)
185        while data and '\n' not in buf:
186            data = receive(s, 100)
187            buf += data
188        self.assertEqual(buf, TEST_STR)
189        s.close()
190
191    def dgram_examine(self, proto, addr):
192        s = socket.socket(proto, socket.SOCK_DGRAM)
193        if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
194            s.bind(self.pickaddr(proto))
195        s.sendto(TEST_STR, addr)
196        buf = data = receive(s, 100)
197        while data and '\n' not in buf:
198            data = receive(s, 100)
199            buf += data
200        self.assertEqual(buf, TEST_STR)
201        s.close()
202
203    def test_TCPServer(self):
204        self.run_server(SocketServer.TCPServer,
205                        SocketServer.StreamRequestHandler,
206                        self.stream_examine)
207
208    def test_ThreadingTCPServer(self):
209        self.run_server(SocketServer.ThreadingTCPServer,
210                        SocketServer.StreamRequestHandler,
211                        self.stream_examine)
212
213    @requires_forking
214    def test_ForkingTCPServer(self):
215        with simple_subprocess(self):
216            self.run_server(SocketServer.ForkingTCPServer,
217                            SocketServer.StreamRequestHandler,
218                            self.stream_examine)
219
220    @requires_unix_sockets
221    def test_UnixStreamServer(self):
222        self.run_server(SocketServer.UnixStreamServer,
223                        SocketServer.StreamRequestHandler,
224                        self.stream_examine)
225
226    @requires_unix_sockets
227    def test_ThreadingUnixStreamServer(self):
228        self.run_server(SocketServer.ThreadingUnixStreamServer,
229                        SocketServer.StreamRequestHandler,
230                        self.stream_examine)
231
232    @requires_unix_sockets
233    @requires_forking
234    def test_ForkingUnixStreamServer(self):
235        with simple_subprocess(self):
236            self.run_server(ForkingUnixStreamServer,
237                            SocketServer.StreamRequestHandler,
238                            self.stream_examine)
239
240    def test_UDPServer(self):
241        self.run_server(SocketServer.UDPServer,
242                        SocketServer.DatagramRequestHandler,
243                        self.dgram_examine)
244
245    def test_ThreadingUDPServer(self):
246        self.run_server(SocketServer.ThreadingUDPServer,
247                        SocketServer.DatagramRequestHandler,
248                        self.dgram_examine)
249
250    @requires_forking
251    def test_ForkingUDPServer(self):
252        with simple_subprocess(self):
253            self.run_server(SocketServer.ForkingUDPServer,
254                            SocketServer.DatagramRequestHandler,
255                            self.dgram_examine)
256
257    @contextlib.contextmanager
258    def mocked_select_module(self):
259        """Mocks the select.select() call to raise EINTR for first call"""
260        old_select = select.select
261
262        class MockSelect:
263            def __init__(self):
264                self.called = 0
265
266            def __call__(self, *args):
267                self.called += 1
268                if self.called == 1:
269                    # raise the exception on first call
270                    raise select.error(errno.EINTR, os.strerror(errno.EINTR))
271                else:
272                    # Return real select value for consecutive calls
273                    return old_select(*args)
274
275        select.select = MockSelect()
276        try:
277            yield select.select
278        finally:
279            select.select = old_select
280
281    def test_InterruptServerSelectCall(self):
282        with self.mocked_select_module() as mock_select:
283            pid = self.run_server(SocketServer.TCPServer,
284                                  SocketServer.StreamRequestHandler,
285                                  self.stream_examine)
286            # Make sure select was called again:
287            self.assertGreater(mock_select.called, 1)
288
289    @requires_unix_sockets
290    def test_UnixDatagramServer(self):
291        self.run_server(SocketServer.UnixDatagramServer,
292                        SocketServer.DatagramRequestHandler,
293                        self.dgram_examine)
294
295    @requires_unix_sockets
296    def test_ThreadingUnixDatagramServer(self):
297        self.run_server(SocketServer.ThreadingUnixDatagramServer,
298                        SocketServer.DatagramRequestHandler,
299                        self.dgram_examine)
300
301    @requires_unix_sockets
302    @requires_forking
303    def test_ForkingUnixDatagramServer(self):
304        self.run_server(ForkingUnixDatagramServer,
305                        SocketServer.DatagramRequestHandler,
306                        self.dgram_examine)
307
308    @reap_threads
309    def test_shutdown(self):
310        # Issue #2302: shutdown() should always succeed in making an
311        # other thread leave serve_forever().
312        class MyServer(SocketServer.TCPServer):
313            pass
314
315        class MyHandler(SocketServer.StreamRequestHandler):
316            pass
317
318        threads = []
319        for i in range(20):
320            s = MyServer((HOST, 0), MyHandler)
321            t = threading.Thread(
322                name='MyServer serving',
323                target=s.serve_forever,
324                kwargs={'poll_interval':0.01})
325            t.daemon = True  # In case this function raises.
326            threads.append((t, s))
327        for t, s in threads:
328            t.start()
329            s.shutdown()
330        for t, s in threads:
331            t.join()
332            close_server(s)
333
334    def test_tcpserver_bind_leak(self):
335        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
336        # failed.
337        # Create many servers for which bind() will fail, to see if this result
338        # in FD exhaustion.
339        for i in range(1024):
340            with self.assertRaises(OverflowError):
341                SocketServer.TCPServer((HOST, -1),
342                                       SocketServer.StreamRequestHandler)
343
344
class MiscTestCase(unittest.TestCase):
    """Server tests that drive a single request without a serving thread."""

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(SocketServer.TCPServer):
            # Number of times shutdown_request() has been invoked.
            shutdown_called = 0

            def verify_request(self, request, client_address):
                # Reject every request.
                return False

            def shutdown_request(self, request):
                self.shutdown_called += 1
                SocketServer.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), SocketServer.StreamRequestHandler)
        try:
            s = socket.socket(server.address_family, socket.SOCK_STREAM)
            try:
                s.connect(server.server_address)
            finally:
                # The pending connection is all the server needs.
                s.close()
            server.handle_request()
            self.assertEqual(server.shutdown_called, 1)
        finally:
            # Release the listening socket even if the test fails.
            close_server(server)
367
368
def test_main():
    """Run every test case of this module through test_support.

    Raises unittest.SkipTest when the import lock is held.
    """
    if imp.lock_held():
        # If the import lock is held, the threads will hang
        raise unittest.SkipTest("can't run when import lock is held")

    # MiscTestCase was previously left out and therefore never executed
    # when the module was run through this entry point.
    test.test_support.run_unittest(SocketServerTest, MiscTestCase)
375
# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()
378