1#!/usr/bin/python3
2#
3# Copyright 2015 Google Inc. All Rights Reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17"""A server to hand out network ports to applications running on one host.
18
19Typical usage:
20 1) Run one instance of this process on each of your unittest farm hosts.
21 2) Set the PORTSERVER_ADDRESS environment variable in your test runner
22    environment to let the portpicker library know to use a port server
23    rather than attempt to find ports on its own.
24
25$ /path/to/portserver.py &
26$ export PORTSERVER_ADDRESS=@unittest-portserver
27$ # ... launch a bunch of unittest runners using portpicker ...
28"""
29
30import argparse
31import asyncio
32import collections
33import logging
34import signal
35import socket
36import sys
37import psutil
38import subprocess
39from datetime import datetime, timezone, timedelta
40
# Module-wide logger; stays None until _configure_logging() replaces it with
# the 'portserver' logging.Logger.
log = None

# (socket type, protocol) pairs a port must be bindable for before it is
# handed out: TCP first, then UDP.
_PROTOS = [
    (socket.SOCK_STREAM, socket.IPPROTO_TCP),
    (socket.SOCK_DGRAM, socket.IPPROTO_UDP),
]
45
46
def _get_process_command_line(pid):
    """Return the command line of process ``pid``, for logging only.

    Args:
      pid: integer process id as sent by a client; it has not necessarily been
          validated yet, so it may be negative, init, or nonexistent.
    Returns:
      The argument list reported by psutil, or '' if it cannot be read.
    """
    try:
        return psutil.Process(pid).cmdline()
    except (psutil.Error, ValueError, OverflowError):
        # psutil.Error covers NoSuchProcess/ZombieProcess as well as
        # AccessDenied (e.g. processes of another user).  psutil.Process()
        # raises ValueError for negative pids, and OverflowError guards pids
        # too large for the platform.  A failed diagnostic lookup must not
        # abort handling of the client's request.
        return ''
52
53
def _get_process_start_time(pid):
    """Return the creation time of process ``pid``, or 0.0 if unavailable.

    Args:
      pid: integer process id.
    Returns:
      Seconds since the epoch as reported by psutil's create_time(), or 0.0
      when the process is gone or cannot be inspected.
    """
    try:
        return psutil.Process(pid).create_time()
    except (psutil.Error, ValueError, OverflowError):
        # psutil.Error covers NoSuchProcess/ZombieProcess and AccessDenied;
        # ValueError/OverflowError cover pids psutil rejects as out of range.
        # 0.0 is the "unknown start time" value for callers, so this never
        # raises into the port allocation loop.
        return 0.0
59
60
61# TODO: Consider importing portpicker.bind() instead of duplicating the code.
def _bind(port, socket_type, socket_proto):
    """Probe whether ``port`` can be bound for the given socket type/protocol.

    IPv6 is tried first, then IPv4.  An address family whose socket cannot
    even be created is skipped; every family that can be created must bind
    successfully.  At least one family has to be usable.

    Args:
      port: The port number to bind to, or 0 to have the OS pick a free port.
      socket_type: The type of the socket (ex: socket.SOCK_STREAM).
      socket_proto: The protocol of the socket (ex: socket.IPPROTO_TCP).

    Returns:
      The port number on success or None on failure.
    """
    bound_port = port
    any_family_supported = False
    for family in (socket.AF_INET6, socket.AF_INET):
        try:
            probe = socket.socket(family, socket_type, socket_proto)
        except socket.error:
            # This address family is unavailable on the host; try the next.
            continue
        any_family_supported = True
        with probe:
            try:
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                probe.bind(('', bound_port))
                if socket_type == socket.SOCK_STREAM:
                    probe.listen(1)
                # When port was 0, later families reuse the OS-chosen port.
                bound_port = probe.getsockname()[1]
            except socket.error:
                return None
    if not any_family_supported:
        return None
    return bound_port
95
96
def _is_port_free(port):
    """Check if specified port is free.

    Args:
      port: integer, port to check
    Returns:
      boolean, whether it is free to use for both TCP and UDP
    """
    # _bind() returns the port number or None.  Normalize that to a real bool
    # as documented.  all() stops at the first protocol that fails, so UDP is
    # only probed after TCP succeeded (the order of _PROTOS).
    return all(_bind(port, sock_type, sock_proto) is not None
               for sock_type, sock_proto in _PROTOS)
106
107
def _should_allocate_port(pid):
    """Determine if we should allocate a port for use by the given process id.

    Args:
      pid: integer process id parsed from a client request.
    Returns:
      True if pid names an existing process other than init (pid 1); False
      otherwise, after logging the reason.
    """
    if pid <= 0:
        log.info('Not allocating a port to invalid pid')
        return False
    if pid == 1:
        # The client probably meant to send us its parent pid but
        # had been reparented to init.
        log.info('Not allocating a port to init.')
        return False

    try:
        pid_exists = psutil.pid_exists(pid)
    except OverflowError:
        # Clients may send up to 20 digits.  A pid too large for the OS pid
        # type cannot name a real process, and psutil may let OverflowError
        # escape (e.g. from os.kill() on POSIX) instead of returning False.
        pid_exists = False
    if not pid_exists:
        log.info('Not allocating a port to a non-existent process')
        return False
    return True
123
124
async def _start_windows_server(client_connected_cb, path):
    """Serve ``client_connected_cb`` on the Windows named pipe ``path``.

    Every pipe connection gets a fresh StreamReader wrapped in a
    StreamReaderProtocol, so the callback receives a (reader, writer) pair.

    Returns:
      The first server object returned by the loop's start_serving_pipe().
    """
    loop = asyncio.get_event_loop()

    def build_protocol():
        reader = asyncio.StreamReader()
        return asyncio.StreamReaderProtocol(reader, client_connected_cb)

    pipe_server, *_ = await loop.start_serving_pipe(
        build_protocol, address=path)
    return pipe_server
137
138
class _PortInfo(object):
    """Bookkeeping record for a single managed port.

    Attributes:
      port: integer port number.
      pid: integer id of the process the port was last handed to, or 0 if it
          has never been assigned.
      start_time: start time (seconds since the epoch) of that process, or 0.0
          when unassigned or unknown.
    """

    # Fixed attribute set: no per-instance __dict__ for the many records.
    __slots__ = ('port', 'pid', 'start_time')

    def __init__(self, port):
        self.port, self.pid, self.start_time = port, 0, 0.0
154
155
class _PortPool(object):
    """Manage available ports for processes.

    Ports are reclaimed when the reserving process exits and the reserved port
    is no longer in use.  Only ports which are free for both TCP and UDP will be
    handed out.  It is easier to not differentiate between protocols.

    The pool must be pre-seeded with add_port_to_free_pool() calls
    after which get_port_for_process() will allocate and reclaim ports.
    The len() of a _PortPool (like num_ports()) returns the total number of
    ports being managed.

    Attributes:
      ports_checked_for_last_request: The number of ports examined in order to
          return from the most recent get_port_for_process() request.  A high
          number here likely means the number of available ports with no active
          process using them is getting low.
    """

    def __init__(self):
        # _PortInfo records.  get_port_for_process() rotates through them,
        # taking candidates from the right end and re-inserting on the left.
        self._port_queue = collections.deque()
        self.ports_checked_for_last_request = 0

    def __len__(self):
        """Return the total number of ports being managed."""
        return len(self._port_queue)

    def num_ports(self):
        """Return the total number of ports being managed (same as len())."""
        return len(self._port_queue)

    def get_port_for_process(self, pid):
        """Allocates and returns port for pid or 0 if none could be allocated.

        Each managed port is examined at most once, in round-robin order.  A
        port is reusable if it was never assigned, or if its recorded owner
        start time no longer matches (e.g. the process exited or its pid was
        reused).  It must also currently be bindable for both TCP and UDP.

        Args:
          pid: integer id of the process the port is being handed to.
        Returns:
          The allocated port number, or 0 if no port could be allocated.
        Raises:
          RuntimeError: if no ports have been added to the pool.
        """
        if not self._port_queue:
            raise RuntimeError('No ports being managed.')

        # Bounded by the pool size to avoid looping forever when every port
        # is currently assigned.
        max_ports_to_test = len(self._port_queue)
        for check_count in range(1, max_ports_to_test + 1):
            # Get the next candidate port and move it to the back of the queue.
            candidate = self._port_queue.pop()
            self._port_queue.appendleft(candidate)
            owner_still_running = (
                candidate.start_time != 0.0 and
                candidate.start_time == _get_process_start_time(candidate.pid))
            if owner_still_running:
                continue
            if not _is_port_free(candidate.port):
                log.info(
                    'Port %d unexpectedly in use, last owning pid %d.',
                    candidate.port, candidate.pid)
                continue
            candidate.pid = pid
            candidate.start_time = _get_process_start_time(pid)
            if not candidate.start_time:
                log.info("Can't read start time for pid %d.", pid)
            self.ports_checked_for_last_request = check_count
            return candidate.port

        log.info('All ports in use.')
        self.ports_checked_for_last_request = max_ports_to_test
        return 0

    def add_port_to_free_pool(self, port):
        """Add a new port to the free pool for allocation.

        Raises:
          ValueError: if port is outside the [1, 65535] range.
        """
        if port < 1 or port > 65535:
            raise ValueError(
                'Port must be in the [1, 65535] range, not %d.' % port)
        self._port_queue.append(_PortInfo(port=port))
219
220
class _PortServerRequestHandler(object):
    """A class to handle port allocation and status requests.

    Allocates ports to process ids via the dead simple port server protocol
    once handle_port_request has been registered as an asyncio server's
    client-connected callback.  The client sends its pid as decimal text and
    receives the allocated port followed by a newline, or nothing if the
    request is denied.  Statistics can be logged using the dump_stats method.
    """

    def __init__(self, ports_to_serve):
        """Initialize a new port server.

        Args:
          ports_to_serve: A sequence of unique port numbers to test and offer
              up to clients.
        Raises:
          ValueError: if any port is outside the [1, 65535] range.
        """
        self._port_pool = _PortPool()
        self._total_allocations = 0
        self._denied_allocations = 0
        self._client_request_errors = 0
        for port in ports_to_serve:
            self._port_pool.add_port_to_free_pool(port)

    async def handle_port_request(self, reader, writer):
        """Serve one client connection: read the request, reply, and close.

        Args:
          reader: asyncio StreamReader for the client connection.
          writer: asyncio StreamWriter for the client connection.
        """
        try:
            # 100 bytes comfortably exceeds the 20 characters accepted below.
            client_data = await reader.read(100)
            self._handle_port_request(client_data, writer)
        finally:
            # Close even when handling raised, so the connection is not leaked.
            writer.close()

    def _handle_port_request(self, client_data, writer):
        """Given a port request body, parse it and respond appropriately.

        Malformed requests are counted as client errors and get no response.
        Requests that are refused or find no free port are counted as denied.

        Args:
          client_data: The request bytes from the client.
          writer: The asyncio Writer for the response to be written to.
        """
        try:
            if len(client_data) > 20:
                raise ValueError('More than 20 characters in "pid".')
            # int() accepts bytes and ignores surrounding whitespace.
            pid = int(client_data)
        except ValueError as error:
            self._client_request_errors += 1
            log.warning('Could not parse request: %s', error)
            return

        log.info('Request on behalf of pid %d.', pid)
        log.info('cmdline: %s', _get_process_command_line(pid))

        if not _should_allocate_port(pid):
            self._denied_allocations += 1
            return

        port = self._port_pool.get_port_for_process(pid)
        if port > 0:
            self._total_allocations += 1
            writer.write('{:d}\n'.format(port).encode('utf-8'))
            log.debug('Allocated port %d to pid %d', port, pid)
        else:
            self._denied_allocations += 1

    def dump_stats(self):
        """Logs statistics of our operation."""
        log.info('Dumping statistics:')
        stats = []
        stats.append(
            'client-request-errors {}'.format(self._client_request_errors))
        stats.append('denied-allocations {}'.format(self._denied_allocations))
        stats.append('num-ports-managed {}'.format(self._port_pool.num_ports()))
        stats.append('num-ports-checked-for-last-request {}'.format(
            self._port_pool.ports_checked_for_last_request))
        stats.append('total-allocations {}'.format(self._total_allocations))
        for stat in stats:
            log.info(stat)
292
293
def _parse_command_line():
    """Declare the portserver's flags and parse them from ``sys.argv``.

    Returns:
      An argparse.Namespace with portserver_static_pool, portserver_address
      (also settable through the legacy --portserver_unix_socket_address
      spelling), verbose and debug.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--portserver_static_pool', type=str, default='15000-24999',
        help='Comma separated N-P Range(s) of ports to manage (inclusive).')

    address_help = (
        'Address of AF_UNIX socket on which to listen on Unix (first @ is '
        'a NUL) or the name of the pipe on Windows (first @ is the '
        r'\\.\pipe\ prefix).')
    # The second option string keeps older invocations working.
    arg_parser.add_argument(
        '--portserver_address', '--portserver_unix_socket_address',
        type=str, default='@unittest-portserver', help=address_help)

    boolean_flags = (
        ('--verbose', 'Enable verbose messages.'),
        ('--debug', 'Enable full debug messages.'),
    )
    for flag, description in boolean_flags:
        arg_parser.add_argument(
            flag, action='store_true', default=False, help=description)

    return arg_parser.parse_args(sys.argv[1:])
319
320
def _parse_port_ranges(pool_str):
    """Given a 'N-P,X-Y' description of port ranges, return a set of ints.

    Each comma-separated range is inclusive at both ends.  Ranges that cannot
    be parsed, fall outside [1, 65535], or are reversed (N > P) are logged and
    skipped.  Overlapping ranges are merged.

    Args:
      pool_str: string such as '15000-24999' or '15000-15999,20000-20999'.
    Returns:
      set of int port numbers (empty if no range was usable).
    """
    ports = set()
    for range_str in pool_str.split(','):
        try:
            a, b = range_str.split('-', 1)
            start, end = int(a), int(b)
        except ValueError:
            log.error('Ignoring unparsable port range %r.', range_str)
            continue
        if start < 1 or end > 65535:
            log.error('Ignoring out of bounds port range %r.', range_str)
            continue
        if start > end:
            # range() would silently yield nothing; surface the likely typo.
            log.error('Ignoring reversed port range %r.', range_str)
            continue
        ports.update(range(start, end + 1))
    return ports
336
337
def _configure_logging(verbose=False, debug=False):
    """Set up root logging and bind the module-level ``log`` logger.

    Args:
      verbose: if True, the 'portserver' logger itself logs at DEBUG.
      debug: if True, the root configuration (and, unless verbose, the
          'portserver' logger) uses DEBUG instead of INFO.
    """
    global log
    root_level = logging.INFO
    if debug:
        root_level = logging.DEBUG
    logging.basicConfig(
        level=root_level,
        style='{',
        datefmt='%m%d %H:%M:%S',
        format=('{levelname[0]}{asctime}.{msecs:03.0f} {thread} '
                '{filename}:{lineno}] {message}'))
    log = logging.getLogger('portserver')
    # Only our own logger gets the verbose level; leaving the root at
    # root_level keeps library chatter (e.g. asyncio debug) out.
    log.setLevel(logging.DEBUG if verbose else root_level)
352
353
def main():
    """Run the port server until interrupted.

    Parses flags, configures logging and builds the port pool from
    --portserver_static_pool, exiting with status 1 if that yields no ports.
    It then serves requests on a Windows named pipe or an AF_UNIX socket until
    ^C.  On Unix, SIGUSR1 logs statistics; they are also logged at shutdown.
    """
    config = _parse_command_line()
    _configure_logging(verbose=config.verbose, debug=config.debug)
    ports_to_serve = _parse_port_ranges(config.portserver_static_pool)
    if not ports_to_serve:
        log.error('No ports.  Invalid port ranges in --portserver_static_pool?')
        sys.exit(1)

    request_handler = _PortServerRequestHandler(ports_to_serve)

    # Create and install the loop explicitly.  Relying on
    # asyncio.get_event_loop() to create one implicitly is deprecated, and
    # Python 3.14 no longer does it.
    if sys.platform == 'win32':
        # Named pipes (start_serving_pipe) need the proactor event loop.
        event_loop = asyncio.ProactorEventLoop()
    else:
        event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    if config.debug:
        # Public equivalent of PYTHONASYNCIODEBUG=1.  The old private
        # asyncio.tasks._DEBUG flag has no effect on modern Pythons.
        event_loop.set_debug(True)

    if sys.platform == 'win32':
        # On Windows, we need to periodically pause the loop to allow the user
        # to send a break signal (e.g. ctrl+c)
        def listen_for_signal():
            event_loop.call_later(0.5, listen_for_signal)

        event_loop.call_later(0.5, listen_for_signal)

        coro = _start_windows_server(
            request_handler.handle_port_request,
            path=config.portserver_address.replace('@', '\\\\.\\pipe\\', 1))
    else:
        event_loop.add_signal_handler(
            signal.SIGUSR1, request_handler.dump_stats) # pylint: disable=no-member

        # start_unix_server() lost its loop= parameter in Python 3.10.
        old_py_loop = {'loop': event_loop} if sys.version_info < (3, 10) else {}
        coro = asyncio.start_unix_server(
            request_handler.handle_port_request,
            path=config.portserver_address.replace('@', '\0', 1),
            **old_py_loop)

    server = event_loop.run_until_complete(coro)
    log.info('Serving on %s', config.portserver_address)
    try:
        event_loop.run_forever()
    except KeyboardInterrupt:
        log.info('Stopping due to ^C.')

    server.close()

    if sys.platform != 'win32':
        # PipeServer doesn't have a wait_closed() function
        event_loop.run_until_complete(server.wait_closed())
        event_loop.remove_signal_handler(signal.SIGUSR1) # pylint: disable=no-member

    event_loop.close()
    request_handler.dump_stats()
    log.info('Goodbye.')
412
413
if __name__ == '__main__':
    # Start the server only when run as a script, not when imported.
    main()
416