1"""
2Test suite for socketserver.
3"""
4
5import contextlib
6import io
7import os
8import select
9import signal
10import socket
11import tempfile
12import threading
13import unittest
14import socketserver
15
16import test.support
17from test.support import reap_children, verbose
18from test.support import os_helper
19from test.support import socket_helper
20from test.support import threading_helper
21
22
23test.support.requires("network")
24test.support.requires_working_socket(module=True)
25
26
27TEST_STR = b"hello world\n"
28HOST = socket_helper.HOST
29
30HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
31requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
32                                            'requires Unix sockets')
33HAVE_FORKING = test.support.has_fork_support
34requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')
35
def signal_alarm(n):
    """Arm (or, with n=0, cancel) a SIGALRM via signal.alarm(n).

    Does nothing on platforms whose signal module lacks alarm(),
    such as Windows.
    """
    alarm = getattr(signal, 'alarm', None)
    if alarm is None:
        return
    alarm(n)
40
# Keep a reference to the genuine select.select() so that tests which mock
# the select module cannot interfere with receive().
_real_select = select.select

def receive(sock, n, timeout=test.support.SHORT_TIMEOUT):
    """Return up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Raises RuntimeError if *sock* does not become readable in time.
    """
    readable, _writable, _exceptional = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)
50
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
    # socketserver provides no ready-made forking Unix-domain servers, so the
    # tests compose them from ForkingMixIn and the Unix server classes.
    class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                  socketserver.UnixStreamServer):
        """UnixStreamServer combined with ForkingMixIn."""

    class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                    socketserver.UnixDatagramServer):
        """UnixDatagramServer combined with ForkingMixIn."""
59
60
@contextlib.contextmanager
def simple_subprocess(testcase):
    """Run the body while an unrelated child process exists (Issue 1540386).

    Forks a trivial child that immediately exits with status 72 and is not
    owned by any server.  The ForkingMixIn tests run inside this context,
    presumably to check that a server's child reaping does not wait on (or
    trip over) processes it did not create.  On exit, including exit via an
    exception, the child is reaped and its exit code is checked.

    *testcase* is accepted for call-site compatibility and is unused.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Exceptions from the body propagate unchanged after the child
        # has been reaped.
        test.support.wait_process(pid, exitcode=72)
74
75
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        # Unix socket paths handed out by pickaddr(); removed in tearDown().
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        for fn in self.test_files:
            try:
                os.remove(fn)
            except OSError:
                # Best effort: the path may never have been bound.
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address to bind a server of address family *proto*.

        For AF_INET this is (HOST, 0), so the OS picks a free port.  For
        any other family it returns a fresh temporary path name and records
        it in self.test_files for removal in tearDown().
        """
        if proto == socket.AF_INET:
            return (HOST, 0)
        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        fn = tempfile.mktemp(prefix='unix_socket.')
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create an echo server of class *svrcls* bound to *addr*.

        Its handler (derived from *hdlrbase*) echoes back one line.  Handler
        errors are not swallowed: handle_error() closes the request and
        re-raises.  The test is skipped if binding is not permitted.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        try:
            server = MyServer(addr, MyHandler)
        except PermissionError as e:
            # Issue 29184: cannot bind() a Unix socket on Android.
            self.skipTest('Cannot create server (%s, %s): %s' %
                          (svrcls, addr, e))
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @threading_helper.reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve from a background thread and run *testfunc* against it.

        *testfunc* is called three times as testfunc(address_family, addr).
        The server is always shut down, its thread joined and its socket
        closed, even when *testfunc* fails, so a failing client check does
        not leak the serving thread or the bound socket.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("ADDR =", addr)
            print("CLASS =", svrcls)

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        # Start the thread outside the try: shutdown() waits for
        # serve_forever() to stop, so it must not run if the thread
        # never started.
        t.start()
        if verbose: print("server running")
        try:
            for i in range(3):
                if verbose: print("test client", i)
                testfunc(svrcls.address_family, addr)
            if verbose: print("waiting for server")
        finally:
            server.shutdown()
            t.join()
            server.server_close()
        self.assertEqual(-1, server.socket.fileno())
        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
            # bpo-31151: Check that ForkingMixIn.server_close() waits until
            # all children completed
            self.assertFalse(server.active_children)
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and assert it is echoed."""
        with socket.socket(proto, socket.SOCK_STREAM) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and assert it is echoed."""
        with socket.socket(proto, socket.SOCK_DGRAM) as s:
            if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
                # Give the client its own path so the server has an
                # address to send the reply to.
                s.bind(self.pickaddr(proto))
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and b'\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(socketserver.TCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(socketserver.ThreadingTCPServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingTCPServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(socketserver.UnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(socketserver.ThreadingUnixStreamServer,
                        socketserver.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            socketserver.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(socketserver.UDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(socketserver.ThreadingUDPServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(socketserver.ForkingUDPServer,
                            socketserver.DatagramRequestHandler,
                            self.dgram_examine)

    @requires_unix_sockets
    def test_UnixDatagramServer(self):
        self.run_server(socketserver.UnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    def test_ThreadingUnixDatagramServer(self):
        self.run_server(socketserver.ThreadingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixDatagramServer(self):
        self.run_server(ForkingUnixDatagramServer,
                        socketserver.DatagramRequestHandler,
                        self.dgram_examine)

    @threading_helper.reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(socketserver.TCPServer):
            pass

        class MyHandler(socketserver.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            s.server_close()

    def test_close_immediately(self):
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), lambda: None)
        server.server_close()

    def test_tcpserver_bind_leak(self):
        # Issue #22435: the server socket wouldn't be closed if bind()/listen()
        # failed.
        # Create many servers for which bind() will fail, to see if this result
        # in FD exhaustion.
        for i in range(1024):
            with self.assertRaises(OverflowError):
                socketserver.TCPServer((HOST, -1),
                                       socketserver.StreamRequestHandler)

    def test_context_manager(self):
        with socketserver.TCPServer((HOST, 0),
                                    socketserver.StreamRequestHandler) as server:
            pass
        self.assertEqual(-1, server.socket.fileno())
304
305
class ErrorHandlerTest(unittest.TestCase):
    """Test that the servers pass normal exceptions from the handler to
    handle_error(), and that exiting exceptions like SystemExit and
    KeyboardInterrupt are not passed.

    The error-test servers append lines to os_helper.TESTFN, and
    check_result() compares that log with the expected contents.
    """

    def tearDown(self):
        os_helper.unlink(os_helper.TESTFN)

    def check_result(self, handled):
        """Assert the log shows one handler call, followed by one handled
        error when *handled* is true."""
        with open(os_helper.TESTFN) as logfile:
            logged = logfile.read()
        wanted = 'Handler called\n' + handled * 'Error handled\n'
        self.assertEqual(logged, wanted)

    # --- synchronous server ---

    def test_sync_handled(self):
        BaseErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_sync_not_handled(self):
        with self.assertRaises(SystemExit):
            BaseErrorTestServer(SystemExit)
        self.check_result(handled=False)

    # --- threading server ---

    def test_threading_handled(self):
        ThreadingErrorTestServer(ValueError)
        self.check_result(handled=True)

    def test_threading_not_handled(self):
        with threading_helper.catch_threading_exception() as caught:
            ThreadingErrorTestServer(SystemExit)
            self.check_result(handled=False)
            # Must be checked inside the with-block, while the captured
            # exception information is still available.
            self.assertIs(caught.exc_type, SystemExit)

    # --- forking server ---

    @requires_forking
    def test_forking_handled(self):
        ForkingErrorTestServer(ValueError)
        self.check_result(handled=True)

    @requires_forking
    def test_forking_not_handled(self):
        ForkingErrorTestServer(SystemExit)
        self.check_result(handled=False)
348
349
class BaseErrorTestServer(socketserver.TCPServer):
    """TCP server that handles exactly one request using BadHandler.

    All the work happens in the constructor: bind, connect a throw-away
    client, handle that single request, close the server, then call
    wait_done().  *exception* is the exception class BadHandler raises.
    """

    def __init__(self, exception):
        self.exception = exception
        super().__init__((HOST, 0), BadHandler)
        # BadHandler never reads from the client, so the client socket can
        # be closed as soon as the connection is established.
        socket.create_connection(self.server_address).close()
        try:
            self.handle_request()
        finally:
            self.server_close()
        self.wait_done()

    def handle_error(self, request, client_address):
        # Record that the handler's exception reached handle_error().
        with open(os_helper.TESTFN, 'a') as logfile:
            logfile.write('Error handled\n')

    def wait_done(self):
        """Hook for subclasses to block until the request is finished; the
        base server has nothing to wait for."""
368
369
class BadHandler(socketserver.BaseRequestHandler):
    """Request handler that logs its invocation, then raises the exception
    class stored on its server as ``exception``."""

    def handle(self):
        with open(os_helper.TESTFN, 'a') as logfile:
            logfile.write('Handler called\n')
        exc_type = self.server.exception
        raise exc_type('Test error')
375
376
class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
                               BaseErrorTestServer):
    """Threaded variant of BaseErrorTestServer.

    wait_done() blocks until shutdown_request() has been called for the
    request.
    """

    def __init__(self, *args, **kwargs):
        # The event must exist before the base constructor runs, because
        # that constructor already serves the request and calls wait_done().
        self.done = threading.Event()
        super().__init__(*args, **kwargs)

    def shutdown_request(self, *args, **kwargs):
        super().shutdown_request(*args, **kwargs)
        self.done.set()

    def wait_done(self):
        self.done.wait()
389
390
if HAVE_FORKING:
    class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
        """BaseErrorTestServer combined with ForkingMixIn; only defined on
        platforms with fork support."""
394
395
class SocketWriterTest(unittest.TestCase):
    """Tests for the wfile object StreamRequestHandler exposes to handlers."""

    def test_basics(self):
        # wfile is a BufferedIOBase sharing the request socket's descriptor.
        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.wfile = self.wfile
                self.server.wfile_fileno = self.wfile.fileno()
                self.server.request_fileno = self.request.fileno()

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM,
                           socket.IPPROTO_TCP) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertIsInstance(server.wfile, io.BufferedIOBase)
        self.assertEqual(server.wfile_fileno, server.request_fileno)

    def test_write(self):
        # Test that wfile.write() sends data immediately, and that it does
        # not truncate sends when interrupted by a Unix signal
        pthread_kill = test.support.get_attribute(signal, 'pthread_kill')

        class Handler(socketserver.StreamRequestHandler):
            def handle(self):
                self.server.sent1 = self.wfile.write(b'write data\n')
                # Should be sent immediately, without requiring flush()
                self.server.received = self.rfile.readline()
                big_chunk = b'\0' * test.support.SOCK_MAX_SIZE
                self.server.sent2 = self.wfile.write(big_chunk)

        server = socketserver.TCPServer((HOST, 0), Handler)
        self.addCleanup(server.server_close)
        interrupted = threading.Event()

        def signal_handler(signum, frame):
            interrupted.set()

        original = signal.signal(signal.SIGUSR1, signal_handler)
        self.addCleanup(signal.signal, signal.SIGUSR1, original)
        response1 = None
        received2 = None
        main_thread = threading.get_ident()

        def run_client():
            s = socket.socket(server.address_family, socket.SOCK_STREAM,
                socket.IPPROTO_TCP)
            with s, s.makefile('rb') as reader:
                s.connect(server.server_address)
                nonlocal response1
                response1 = reader.readline()
                s.sendall(b'client response\n')

                reader.read(100)
                # The main thread should now be blocking in a send() syscall.
                # But in theory, it could get interrupted by other signals,
                # and then retried. So keep sending the signal in a loop, in
                # case an earlier signal happens to be delivered at an
                # inconvenient moment.
                while True:
                    pthread_kill(main_thread, signal.SIGUSR1)
                    if interrupted.wait(timeout=1.0):
                        break
                nonlocal received2
                received2 = len(reader.read())

        background = threading.Thread(target=run_client)
        background.start()
        server.handle_request()
        background.join()
        self.assertEqual(server.sent1, len(response1))
        self.assertEqual(response1, b'write data\n')
        self.assertEqual(server.received, b'client response\n')
        self.assertEqual(server.sent2, test.support.SOCK_MAX_SIZE)
        # The client consumed the first 100 bytes with reader.read(100).
        self.assertEqual(received2, test.support.SOCK_MAX_SIZE - 100)
471
472
class MiscTestCase(unittest.TestCase):
    """Assorted regression tests for socketserver."""

    def test_all(self):
        # objects defined in the module should be in __all__
        expected = []
        for name in dir(socketserver):
            if not name.startswith('_'):
                mod_object = getattr(socketserver, name)
                if getattr(mod_object, '__module__', None) == 'socketserver':
                    expected.append(name)
        self.assertCountEqual(socketserver.__all__, expected)

    def test_shutdown_request_called_if_verify_request_false(self):
        # Issue #26309: BaseServer should call shutdown_request even if
        # verify_request is False

        class MyServer(socketserver.TCPServer):
            def verify_request(self, request, client_address):
                return False

            shutdown_called = 0
            def shutdown_request(self, request):
                self.shutdown_called += 1
                socketserver.TCPServer.shutdown_request(self, request)

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the server even if the assertion below fails.
        self.addCleanup(server.server_close)
        with socket.socket(server.address_family, socket.SOCK_STREAM) as s:
            s.connect(server.server_address)
        server.handle_request()
        self.assertEqual(server.shutdown_called, 1)

    def test_threads_reaped(self):
        """
        In #37193, users reported a memory leak
        due to the saving of every request thread. Ensure that
        not all threads are kept forever.
        """
        class MyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
            pass

        server = MyServer((HOST, 0), socketserver.StreamRequestHandler)
        # Close the server even if the assertion below fails.
        self.addCleanup(server.server_close)
        for _ in range(10):
            with socket.create_connection(server.server_address):
                server.handle_request()
        self.assertLess(len(server._threads), 10)
521
522
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
525