1#!/usr/bin/python
2#
3# Copyright 2007 Google Inc. All Rights Reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17"""Unittests for the portpicker module."""
18
19from __future__ import print_function
20import errno
21import os
22import random
23import socket
24import sys
25import unittest
26from contextlib import ExitStack
27
28if sys.platform == 'win32':
29    import _winapi
30else:
31    _winapi = None
32
33try:
34    # pylint: disable=no-name-in-module
35    from unittest import mock  # Python >= 3.3.
36except ImportError:
37    import mock  # https://pypi.python.org/pypi/mock
38
39import portpicker
40
41
class PickUnusedPortTest(unittest.TestCase):
    """Tests for portpicker's port selection, port-server and helper APIs.

    Several tests exercise real sockets on the host and are therefore
    somewhat flaky; those retry, or tolerate a bounded number of failures,
    instead of requiring every attempt to succeed.
    """

    def IsUnusedTCPPort(self, port):
        """Return the real portpicker.bind() result for `port` over TCP.

        Uses the bind function captured in setUp, so a genuine bind is
        attempted even while a test has patched portpicker.bind. The tests
        treat a truthy result as "the port could be bound".
        """
        return self._bind(port, socket.SOCK_STREAM, socket.IPPROTO_TCP)

    def IsUnusedUDPPort(self, port):
        """Return the real portpicker.bind() result for `port` over UDP."""
        return self._bind(port, socket.SOCK_DGRAM, socket.IPPROTO_UDP)

    def setUp(self):
        # So we can Bind even if portpicker.bind is stubbed out.
        self._bind = portpicker.bind
        # Reset portpicker's module-level bookkeeping so that ports recorded
        # by one test cannot influence the next one.
        portpicker._owned_ports.clear()
        portpicker._free_ports.clear()
        portpicker._random_ports.clear()

    def testPickUnusedPortActuallyWorks(self):
        """This test can be flaky."""
        for _ in range(10):
            port = portpicker.pick_unused_port()
            self.assertTrue(self.IsUnusedTCPPort(port))
            self.assertTrue(self.IsUnusedUDPPort(port))

    @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ,
                     'no port server to test against')
    def testPickUnusedCanSuccessfullyUsePortServer(self):
        with mock.patch.object(portpicker, '_pick_unused_port_without_server'):
            portpicker._pick_unused_port_without_server.side_effect = (
                Exception('eek!')
            )

            # Since _pick_unused_port_without_server() raises an exception,
            # if we can successfully obtain a port, the portserver must be
            # working.
            port = portpicker.pick_unused_port()
            self.assertTrue(self.IsUnusedTCPPort(port))
            self.assertTrue(self.IsUnusedUDPPort(port))

    @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ,
                     'no port server to test against')
    def testPickUnusedCanSuccessfullyUsePortServerAddressKwarg(self):
        with mock.patch.object(portpicker, '_pick_unused_port_without_server'):
            portpicker._pick_unused_port_without_server.side_effect = (
                Exception('eek!')
            )

            # Since _pick_unused_port_without_server() raises an exception,
            # and we've temporarily removed PORTSERVER_ADDRESS from
            # os.environ, if we can successfully obtain a port, the
            # portserver must be working.
            addr = os.environ.pop('PORTSERVER_ADDRESS')
            try:
                port = portpicker.pick_unused_port(portserver_address=addr)
                self.assertTrue(self.IsUnusedTCPPort(port))
                self.assertTrue(self.IsUnusedUDPPort(port))
            finally:
                # Restore the environment for the tests that follow.
                os.environ['PORTSERVER_ADDRESS'] = addr

    @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ,
                     'no port server to test against')
    def testGetPortFromPortServer(self):
        """Exercise the get_port_from_port_server() helper function."""
        for _ in range(10):
            port = portpicker.get_port_from_port_server(
                os.environ['PORTSERVER_ADDRESS'])
            self.assertTrue(self.IsUnusedTCPPort(port))
            self.assertTrue(self.IsUnusedUDPPort(port))

    def testSendsPidToPortServer(self):
        # On Windows the port server is reached through _winapi file calls;
        # elsewhere through socket.socket. Either transport is mocked to
        # answer with a canned port, and we check which pid was sent.
        with ExitStack() as stack:
            if _winapi:
                create_file_mock = mock.Mock()
                create_file_mock.return_value = 0
                read_file_mock = mock.Mock()
                write_file_mock = mock.Mock()
                read_file_mock.return_value = (b'42768\n', 0)
                stack.enter_context(
                    mock.patch('_winapi.CreateFile', new=create_file_mock))
                stack.enter_context(
                    mock.patch('_winapi.WriteFile', new=write_file_mock))
                stack.enter_context(
                    mock.patch('_winapi.ReadFile', new=read_file_mock))
                port = portpicker.get_port_from_port_server(
                    'portserver', pid=1234)
                write_file_mock.assert_called_once_with(0, b'1234\n')
            else:
                server = mock.Mock()
                server.recv.return_value = b'42768\n'
                stack.enter_context(
                    mock.patch.object(socket, 'socket', return_value=server))
                port = portpicker.get_port_from_port_server(
                    'portserver', pid=1234)
                server.sendall.assert_called_once_with(b'1234\n')

        self.assertEqual(port, 42768)

    def testPidDefaultsToOwnPid(self):
        # Same transport mocking as testSendsPidToPortServer, but without an
        # explicit pid: the value reported by os.getpid() must be sent.
        with ExitStack() as stack:
            stack.enter_context(
                mock.patch.object(os, 'getpid', return_value=9876))

            if _winapi:
                create_file_mock = mock.Mock()
                create_file_mock.return_value = 0
                read_file_mock = mock.Mock()
                write_file_mock = mock.Mock()
                read_file_mock.return_value = (b'52768\n', 0)
                stack.enter_context(
                    mock.patch('_winapi.CreateFile', new=create_file_mock))
                stack.enter_context(
                    mock.patch('_winapi.WriteFile', new=write_file_mock))
                stack.enter_context(
                    mock.patch('_winapi.ReadFile', new=read_file_mock))
                port = portpicker.get_port_from_port_server('portserver')
                write_file_mock.assert_called_once_with(0, b'9876\n')
            else:
                server = mock.Mock()
                server.recv.return_value = b'52768\n'
                stack.enter_context(
                    mock.patch.object(socket, 'socket', return_value=server))
                port = portpicker.get_port_from_port_server('portserver')
                server.sendall.assert_called_once_with(b'9876\n')

        self.assertEqual(port, 52768)

    @mock.patch.dict(os.environ, {'PORTSERVER_ADDRESS': 'portserver'})
    def testReusesPortServerPorts(self):
        with ExitStack() as stack:
            if _winapi:
                read_file_mock = mock.Mock()
                read_file_mock.side_effect = [
                    (b'12345\n', 0),
                    (b'23456\n', 0),
                    (b'34567\n', 0),
                ]
                stack.enter_context(mock.patch('_winapi.CreateFile'))
                stack.enter_context(mock.patch('_winapi.WriteFile'))
                stack.enter_context(
                    mock.patch('_winapi.ReadFile', new=read_file_mock))
            else:
                server = mock.Mock()
                server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n']
                stack.enter_context(
                    mock.patch.object(socket, 'socket', return_value=server))

            self.assertEqual(portpicker.pick_unused_port(), 12345)
            self.assertEqual(portpicker.pick_unused_port(), 23456)
            # A returned port is handed out again instead of the server's
            # next answer (34567).
            portpicker.return_port(12345)
            self.assertEqual(portpicker.pick_unused_port(), 12345)

    @mock.patch.dict(os.environ, {'PORTSERVER_ADDRESS': ''})
    def testDoesntReuseRandomPorts(self):
        ports = set()
        for _ in range(10):
            try:
                port = portpicker.pick_unused_port()
            except portpicker.NoFreePortFoundError:
                # This sometimes happens when not using portserver. Just
                # skip to the next attempt.
                continue
            ports.add(port)
            portpicker.return_port(port)
        self.assertGreater(len(ports), 5)  # Allow some random reuse.

    def testReturnsReservedPorts(self):
        with mock.patch.object(portpicker, '_pick_unused_port_without_server'):
            portpicker._pick_unused_port_without_server.side_effect = (
                Exception('eek!'))
            # Arbitrary port. In practice you should get this from somewhere
            # that assigns ports.
            reserved_port = 28465
            portpicker.add_reserved_port(reserved_port)
            ports = set()
            for _ in range(10):
                port = portpicker.pick_unused_port()
                ports.add(port)
                portpicker.return_port(port)
            self.assertEqual(len(ports), 1)
            self.assertEqual(ports.pop(), reserved_port)

    @mock.patch.dict(os.environ, {'PORTSERVER_ADDRESS': ''})
    def testFallsBackToRandomAfterRunningOutOfReservedPorts(self):
        # Arbitrary port. In practice you should get this from somewhere
        # that assigns ports.
        reserved_port = 23456
        portpicker.add_reserved_port(reserved_port)
        self.assertEqual(portpicker.pick_unused_port(), reserved_port)
        self.assertNotEqual(portpicker.pick_unused_port(), reserved_port)

    def testRandomlyChosenPorts(self):
        # Unless this box is under an overwhelming socket load, this test
        # will heavily exercise the "pick a port randomly" part of the
        # port picking code, but may never hit the "OS assigns a port"
        # code.
        ports = 0
        for _ in range(100):
            try:
                port = portpicker._pick_unused_port_without_server()
            except portpicker.NoFreePortFoundError:
                # Without the portserver, pick_unused_port can sometimes fail
                # to find a free port. Check that it passes most of the time.
                continue
            self.assertTrue(self.IsUnusedTCPPort(port))
            self.assertTrue(self.IsUnusedUDPPort(port))
            ports += 1
        # Getting a port shouldn't have failed very often, even on machines
        # with a heavy socket load.
        self.assertGreater(ports, 95)

    def testOSAssignedPorts(self):
        self.last_assigned_port = None

        def error_for_explicit_ports(port, socket_type, socket_proto):
            # Only successfully return a port if an OS-assigned port is
            # requested, or if we're checking that the last OS-assigned port
            # is unused on the other protocol.
            if port == 0 or port == self.last_assigned_port:
                self.last_assigned_port = self._bind(port, socket_type,
                                                     socket_proto)
                return self.last_assigned_port
            else:
                return None

        with mock.patch.object(portpicker, 'bind', error_for_explicit_ports):
            # Without server, this can be little flaky, so check that it
            # passes most of the time.
            ports = 0
            for _ in range(100):
                try:
                    port = portpicker._pick_unused_port_without_server()
                except portpicker.NoFreePortFoundError:
                    continue
                self.assertTrue(self.IsUnusedTCPPort(port))
                self.assertTrue(self.IsUnusedUDPPort(port))
                ports += 1
            self.assertGreater(ports, 70)

    def pickUnusedPortWithoutServer(self):
        """Pick a port without the server and check it is free on TCP and UDP.

        Fails the test if no port could be obtained in five attempts.
        """
        # Try a few times to pick a port, to avoid flakiness and to make sure
        # the code path we want was exercised.
        for _ in range(5):
            try:
                port = portpicker._pick_unused_port_without_server()
            except portpicker.NoFreePortFoundError:
                continue
            else:
                self.assertTrue(self.IsUnusedTCPPort(port))
                self.assertTrue(self.IsUnusedUDPPort(port))
                return
        self.fail("Failed to find a free port")

    def testPickPortsWithoutServer(self):
        # Test the first part of _pick_unused_port_without_server, which
        # tries a few random ports and checks is_port_free.
        self.pickUnusedPortWithoutServer()

        # Now test the second part, the fallback from above, which asks the
        # OS for a port.
        def mock_port_free(port):
            return False

        with mock.patch.object(portpicker, 'is_port_free', mock_port_free):
            self.pickUnusedPortWithoutServer()

    def checkIsPortFree(self):
        """Check is_port_free() tracks a port being occupied and released.

        This might be flaky unless this test is run with a portserver.

        Raises:
          portpicker.NoFreePortFoundError: if another process grabbed the
            picked port before this check could bind it; the caller retries.
        """
        # The port should be free initially.
        port = portpicker.pick_unused_port()
        self.assertTrue(portpicker.is_port_free(port))

        cases = [
            (socket.AF_INET,  socket.SOCK_STREAM, None),
            (socket.AF_INET6, socket.SOCK_STREAM, 1),
            (socket.AF_INET,  socket.SOCK_DGRAM,  None),
            (socket.AF_INET6, socket.SOCK_DGRAM,  1),
        ]

        # Using v6only=0 on Windows doesn't result in collisions
        if not _winapi:
            cases.extend([
                (socket.AF_INET6, socket.SOCK_STREAM, 0),
                (socket.AF_INET6, socket.SOCK_DGRAM,  0),
            ])

        for (sock_family, sock_type, v6only) in cases:
            # Occupy the port on a subset of possible protocols.
            try:
                sock = socket.socket(sock_family, sock_type, 0)
            except socket.error:
                print('Kernel does not support sock_family=%d' % sock_family,
                      file=sys.stderr)
                # Skip this case, since we cannot occupy a port.
                continue

            if not hasattr(socket, 'IPPROTO_IPV6'):
                v6only = None

            # Always release the socket, even when binding or an assertion
            # fails, so a failed attempt doesn't keep the port occupied for
            # the caller's retries.
            try:
                if v6only is not None:
                    try:
                        sock.setsockopt(socket.IPPROTO_IPV6,
                                        socket.IPV6_V6ONLY, v6only)
                    except socket.error:
                        print('Kernel does not support IPV6_V6ONLY=%d'
                              % v6only, file=sys.stderr)
                        # Don't care; just proceed with the default.

                # Socket may have been taken in the meantime, so catch the
                # socket.error with errno set to EADDRINUSE and skip this
                # attempt.
                try:
                    sock.bind(('', port))
                except socket.error as e:
                    if e.errno == errno.EADDRINUSE:
                        raise portpicker.NoFreePortFoundError from e
                    raise

                # The port should be busy.
                self.assertFalse(portpicker.is_port_free(port))
            finally:
                sock.close()

            # Now it's free again.
            self.assertTrue(portpicker.is_port_free(port))

    def testIsPortFree(self):
        # This can be quite flaky on a busy host, so try a few times.
        for _ in range(10):
            try:
                self.checkIsPortFree()
            except portpicker.NoFreePortFoundError:
                pass
            else:
                return
        self.fail("checkIsPortFree failed every time.")

    def testIsPortFreeException(self):
        port = portpicker.pick_unused_port()
        with mock.patch.object(socket, 'socket') as mock_sock:
            mock_sock.side_effect = socket.error('fake socket error', 0)
            self.assertFalse(portpicker.is_port_free(port))

    def testThatLegacyCapWordsAPIsExist(self):
        """The original APIs were CapWords style, 1.1 added PEP8 names."""
        self.assertEqual(portpicker.bind, portpicker.Bind)
        self.assertEqual(portpicker.is_port_free, portpicker.IsPortFree)
        self.assertEqual(portpicker.pick_unused_port, portpicker.PickUnusedPort)
        self.assertEqual(portpicker.get_port_from_port_server,
                         portpicker.GetPortFromPortServer)
387
388
if __name__ == '__main__':
    # Hand control to unittest's command-line runner, which collects and
    # runs the TestCase classes defined in this module.
    unittest.main(module='__main__')
391