1"""Utilities shared by tests."""
2
3import asyncio
4import collections
5import contextlib
6import io
7import logging
8import os
9import re
10import selectors
11import socket
12import socketserver
13import sys
14import tempfile
15import threading
16import time
17import unittest
18import weakref
19
20from unittest import mock
21
22from http.server import HTTPServer
23from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
24
25try:
26    import ssl
27except ImportError:  # pragma: no cover
28    ssl = None
29
30from asyncio import base_events
31from asyncio import events
32from asyncio import format_helpers
33from asyncio import futures
34from asyncio import tasks
35from asyncio.log import logger
36from test import support
37from test.support import threading_helper
38
39
def data_file(filename):
    """Return the path of the test data file *filename*.

    ``support.TEST_HOME_DIR`` is searched first (when ``support`` defines
    it), then the parent directory of this module.

    Raises FileNotFoundError, carrying *filename*, when neither location
    contains a regular file of that name.
    """
    search_dirs = []
    if hasattr(support, 'TEST_HOME_DIR'):
        search_dirs.append(support.TEST_HOME_DIR)
    search_dirs.append(os.path.join(os.path.dirname(__file__), '..'))

    for directory in search_dirs:
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(filename)
49
50
# Certificate/key fixture files, located through data_file() when this
# module is imported (a missing file raises FileNotFoundError at import).
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
# NOTE(review): presumably the decoded certificate dict that tests expect
# for SIGNED_CERTFILE (same shape as ssl.SSLSocket.getpeercert()) -- confirm
# against the fixture before relying on it.
PEERCERT = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}
72
73
def simple_server_sslcontext():
    """Return a TLS server context loaded with the ONLYCERT/ONLYKEY pair.

    Hostname checking and peer-certificate verification are both
    switched off on the returned context.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
80
81
def simple_client_sslcontext(*, disable_verify=True):
    """Return a TLS client context with hostname checking turned off.

    When *disable_verify* is true (the default) certificate verification
    is disabled as well; otherwise the context keeps the verification
    mode it was created with.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.check_hostname = False
    if not disable_verify:
        return ctx
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
88
89
def dummy_ssl_context():
    """Return a non-verifying client SSL context, or None without ssl.

    None is returned when the ``ssl`` module could not be imported.
    """
    if ssl is not None:
        return simple_client_sslcontext(disable_verify=True)
    return None
95
96
def run_briefly(loop):
    """Run *loop* just long enough to complete one do-nothing task."""
    async def noop():
        pass

    coro = noop()
    task = loop.create_task(coro)
    # The task may still be pending after run_until_complete() if the loop
    # is stopped or a task raises a BaseException; do not log a warning
    # about it in that case.
    task._log_destroy_pending = False
    try:
        loop.run_until_complete(task)
    finally:
        coro.close()
109
110
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
    """Run *loop* in 1 ms slices until ``pred()`` returns a true value.

    *timeout* is the maximum number of seconds to keep trying; ``None``
    means wait without a limit.

    Raises asyncio.TimeoutError if the timeout expires while ``pred()``
    is still false.
    """
    # Only build a deadline when a timeout was given: adding None to the
    # clock value would raise TypeError before the loop is even entered.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not pred():
        if deadline is not None and time.monotonic() >= deadline:
            # Use the package-level exception name; asyncio.futures is not
            # guaranteed to re-export TimeoutError.
            raise asyncio.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001))
119
120
def run_once(loop):
    """Legacy API to run once through the event loop.

    A call to ``loop.stop()`` is queued and the loop is then started, so
    it polls the selector once, runs the callbacks scheduled in response
    to I/O events, and stops.  This is the recommended pattern for test
    code.
    """
    stop = loop.stop
    loop.call_soon(stop)
    loop.run_forever()
130
131
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGI request handler that produces no diagnostic output."""

    def get_stderr(self):
        """Return a fresh in-memory text buffer to use as stderr."""
        sink = io.StringIO()
        return sink

    def log_message(self, format, *args):
        """Discard the log message."""
        return None
139
140
class SilentWSGIServer(WSGIServer):
    """WSGIServer that times out accepted sockets and ignores errors."""

    # Timeout, in seconds, applied to every accepted connection.
    request_timeout = support.LOOPBACK_TIMEOUT

    def get_request(self):
        """Accept a connection and set ``request_timeout`` on its socket."""
        conn, peer = super().get_request()
        conn.settimeout(self.request_timeout)
        return conn, peer

    def handle_error(self, request, client_address):
        """Ignore errors raised while handling a request."""
        return None
152
153
class SSLWSGIServerMixin:
    """Server mixin that handles every accepted connection over TLS."""

    def finish_request(self, request, client_address):
        """Wrap *request* in server-side TLS and run the handler on it.

        A new server context is built from the module-level ONLYCERT and
        ONLYKEY files for each request.  OSError raised by the request
        handler, or while closing the TLS socket, is ignored; an error
        raised by the TLS wrapping itself is not.
        """
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(ONLYCERT, ONLYKEY)

        ssock = context.wrap_socket(request, server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except OSError:
            # maybe socket has been closed by peer
            pass
        finally:
            # Close the TLS socket on every path, so that a failing
            # handler does not leave it open.
            try:
                ssock.close()
            except OSError:
                pass
171
172
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer combined with SSLWSGIServerMixin."""
175
176
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
    """Serve a small WSGI app on *address* from a background thread.

    This is a generator intended to be delegated to from a context
    manager: it yields the running server once, and when it is resumed
    or closed it shuts the server down and joins the thread.

    *server_ssl_cls* is instantiated when *use_ssl* is true, *server_cls*
    otherwise.  The yielded server carries an extra ``address`` attribute
    equal to its ``server_address``.

    The app answers every request with "200 OK" and a text/plain header.
    The path ``/loop`` echoes the request body back; any other path
    returns ``b'Test message'``.
    """

    def loop(environ):
        # Echo the request body back in chunks of at most 64 KiB.
        size = int(environ['CONTENT_LENGTH'])
        while size:
            data = environ['wsgi.input'].read(min(size, 0x10000))
            if not data:
                # The input ended before CONTENT_LENGTH bytes arrived;
                # without this check the loop would yield b'' forever.
                break
            yield data
            size -= len(data)

    def app(environ, start_response):
        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        if environ['PATH_INFO'] == '/loop':
            return loop(environ)
        else:
            return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    server_class = server_ssl_cls if use_ssl else server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(
        target=lambda: httpd.serve_forever(poll_interval=0.05))
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
210
211
212if hasattr(socket, 'AF_UNIX'):
213
214    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
215
216        def server_bind(self):
217            socketserver.UnixStreamServer.server_bind(self)
218            self.server_name = '127.0.0.1'
219            self.server_port = 80
220
221
222    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
223
224        request_timeout = support.LOOPBACK_TIMEOUT
225
226        def server_bind(self):
227            UnixHTTPServer.server_bind(self)
228            self.setup_environ()
229
230        def get_request(self):
231            request, client_addr = super().get_request()
232            request.settimeout(self.request_timeout)
233            # Code in the stdlib expects that get_request
234            # will return a socket and a tuple (host, port).
235            # However, this isn't true for UNIX sockets,
236            # as the second return value will be a path;
237            # hence we return some fake data sufficient
238            # to get the tests going
239            return request, ('127.0.0.1', '')
240
241
242    class SilentUnixWSGIServer(UnixWSGIServer):
243
244        def handle_error(self, request, client_address):
245            pass
246
247
248    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
249        pass
250
251
252    def gen_unix_socket_path():
253        with tempfile.NamedTemporaryFile() as file:
254            return file.name
255
256
257    @contextlib.contextmanager
258    def unix_socket_path():
259        path = gen_unix_socket_path()
260        try:
261            yield path
262        finally:
263            try:
264                os.unlink(path)
265            except OSError:
266                pass
267
268
269    @contextlib.contextmanager
270    def run_test_unix_server(*, use_ssl=False):
271        with unix_socket_path() as path:
272            yield from _run_test_server(address=path, use_ssl=use_ssl,
273                                        server_cls=SilentUnixWSGIServer,
274                                        server_ssl_cls=UnixSSLWSGIServer)
275
276
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
    """Context manager running the test WSGI server on ``(host, port)``.

    Yields the server object produced by _run_test_server().
    """
    server_options = dict(use_ssl=use_ssl,
                          server_cls=SilentWSGIServer,
                          server_ssl_cls=SSLWSGIServer)
    yield from _run_test_server(address=(host, port), **server_options)
282
283
def echo_datagrams(sock):
    """Echo every datagram received on *sock* back to its sender.

    Runs until a datagram equal to ``b'STOP'`` arrives; *sock* is then
    closed and the function returns.
    """
    while True:
        payload, sender = sock.recvfrom(4096)
        if payload == b'STOP':
            sock.close()
            return
        sock.sendto(payload, sender)
292
293
@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
    """Context manager running a UDP echo server in a background thread.

    Yields the address the server socket is bound to.  On exit a
    ``b'STOP'`` datagram is sent to that address, which makes
    echo_datagrams() close the socket and return, and the thread is
    joined.
    """
    addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
    family, socktype, proto, _, sockaddr = addr_info[0]
    sock = socket.socket(family, socktype, proto)
    try:
        # Bind to the address that was resolved for this socket's family
        # rather than resolving (host, port) a second time.
        sock.bind(sockaddr)
        thread = threading.Thread(target=echo_datagrams, args=(sock,))
        thread.start()
    except BaseException:
        # Nothing else owns the socket yet, so close it here.
        sock.close()
        raise
    try:
        yield sock.getsockname()
    finally:
        sock.sendto(b'STOP', sock.getsockname())
        thread.join()
307
308
def make_test_protocol(base):
    """Return an instance of a *base* subclass with mocked attributes.

    Every name listed by ``dir(base)`` that is not a dunder name is
    replaced in the new ``TestProtocol`` class by a MockCallback
    returning None.
    """
    def is_magic(name):
        return name.startswith('__') and name.endswith('__')

    namespace = {
        name: MockCallback(return_value=None)
        for name in dir(base)
        if not is_magic(name)
    }
    protocol_cls = type('TestProtocol', (base,) + base.__bases__, namespace)
    return protocol_cls()
317
318
class TestSelector(selectors.BaseSelector):
    """Selector stub that records registrations and never reports events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        """Record *fileobj* and return its key (the fd field is always 0)."""
        key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = key
        return key

    def unregister(self, fileobj):
        """Forget *fileobj* and return its key; KeyError if unknown."""
        return self.keys.pop(fileobj)

    def select(self, timeout=None):
        """Return an empty list of ready events, ignoring *timeout*.

        *timeout* defaults to None so the signature matches
        BaseSelector.select() and can be called without arguments.
        """
        return []

    def get_map(self):
        """Return the mapping of registered file objects to their keys."""
        return self.keys
337
338
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that runs on a virtual clock.

    The loop never reads a real clock: time() returns an internal
    counter that only advance_time() moves.  Every deadline passed to
    call_at() is remembered, and after each loop iteration those
    deadlines are sent, one at a time, into the generator created from
    the *gen* argument of __init__.  Whatever the generator yields back
    is the amount by which the clock is advanced.

    The generator function should look like this:

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value of a ``yield`` expression is the absolute time of a
    scheduled handler; the value given to ``yield`` is the time advance
    applied to the loop's clock.  The first yielded value is consumed by
    __init__ and discarded.

    When a generator function was supplied, close() sends it a final 0
    and raises AssertionError unless the generator finishes.
    """

    def __init__(self, gen=None):
        super().__init__()

        # Only a caller-supplied generator is checked for exhaustion.
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

        self._transports = weakref.WeakValueDictionary()

    def time(self):
        """Return the current virtual time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward."""
        if not advance:
            return
        self._time += advance

    def close(self):
        """Close the loop and check that the time generator is finished."""
        super().close()
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    def _add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self, None)

    def _remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        try:
            del self.readers[fd]
        except KeyError:
            return False
        return True

    def assert_reader(self, fd, callback, *args):
        """Raise AssertionError unless *fd* reads with *callback*/*args*."""
        if fd not in self.readers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.readers[fd]
        if handle._callback != callback:
            raise AssertionError(
                f'unexpected callback: {handle._callback} != {callback}')
        if handle._args != args:
            raise AssertionError(
                f'unexpected callback args: {handle._args} != {args}')

    def assert_no_reader(self, fd):
        """Raise AssertionError if a reader is registered for *fd*."""
        if fd in self.readers:
            raise AssertionError(f'fd {fd} is registered')

    def _add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self, None)

    def _remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        try:
            del self.writers[fd]
        except KeyError:
            return False
        return True

    def assert_writer(self, fd, callback, *args):
        """Raise AssertionError unless *fd* writes with *callback*/*args*."""
        if fd not in self.writers:
            raise AssertionError(f'fd {fd} is not registered')
        handle = self.writers[fd]
        if handle._callback != callback:
            raise AssertionError(f'{handle._callback!r} != {callback!r}')
        if handle._args != args:
            raise AssertionError(f'{handle._args!r} != {args!r}')

    def _ensure_fd_no_transport(self, fd):
        fileno = fd
        if not isinstance(fileno, int):
            try:
                fileno = int(fileno.fileno())
            except (AttributeError, TypeError, ValueError):
                # This code matches selectors._fileobj_to_fd function.
                raise ValueError(
                    "Invalid file object: {!r}".format(fd)) from None
        try:
            transport = self._transports[fileno]
        except KeyError:
            return
        raise RuntimeError(
            'File descriptor {!r} is used by transport {!r}'.format(
                fileno, transport))

    def add_reader(self, fd, callback, *args):
        """Add a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_reader(fd, callback, *args)

    def remove_reader(self, fd):
        """Remove a reader callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_reader(fd)

    def add_writer(self, fd, callback, *args):
        """Add a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._add_writer(fd, callback, *args)

    def remove_writer(self, fd):
        """Remove a writer callback."""
        self._ensure_fd_no_transport(fd)
        return self._remove_writer(fd)

    def reset_counters(self):
        """Reset the per-fd counts of reader and writer removals."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super()._run_once()
        # Iterate the live list: deadlines appended while the generator
        # runs are handled in this same pass.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args, context=None):
        """Record *when* for the time generator, then schedule as usual."""
        self._timers.append(when)
        return super().call_at(when, callback, *args, context=context)

    def _process_events(self, event_list):
        return None

    def _write_to_self(self):
        return None
502
503
def MockCallback(**kwargs):
    """Return a Mock whose spec contains only ``__call__``.

    Keyword arguments are passed on to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
506
507
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))

    The pattern is applied with re.search() and the re.S flag, so it may
    match anywhere in the other string and '.' also matches newlines.
    Comparing with a non-str object is never equal.
    """
    def __eq__(self, other):
        if not isinstance(other, str):
            # re.search() raises TypeError for non-str input; let Python
            # fall back to its default handling, which compares unequal.
            return NotImplemented
        return bool(re.search(str(self), other, re.S))
519
520
class MockInstanceOf:
    """Object that compares equal to any instance of a given type."""

    def __init__(self, type):
        # Type, or tuple of types, as accepted by isinstance().
        self._type = type

    def __eq__(self, other):
        expected_type = self._type
        return isinstance(other, expected_type)
527
528
def get_function_source(func):
    """Return the source location reported for *func*.

    The value comes from format_helpers._get_function_source().
    Raises ValueError when that helper returns None.
    """
    location = format_helpers._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
534
535
536class TestCase(unittest.TestCase):
537    @staticmethod
538    def close_loop(loop):
539        if loop._default_executor is not None:
540            if not loop.is_closed():
541                loop.run_until_complete(loop.shutdown_default_executor())
542            else:
543                loop._default_executor.shutdown(wait=True)
544        loop.close()
545        policy = support.maybe_get_event_loop_policy()
546        if policy is not None:
547            try:
548                watcher = policy.get_child_watcher()
549            except NotImplementedError:
550                # watcher is not implemented by EventLoopPolicy, e.g. Windows
551                pass
552            else:
553                if isinstance(watcher, asyncio.ThreadedChildWatcher):
554                    threads = list(watcher._threads.values())
555                    for thread in threads:
556                        thread.join()
557
558    def set_event_loop(self, loop, *, cleanup=True):
559        if loop is None:
560            raise AssertionError('loop is None')
561        # ensure that the event loop is passed explicitly in asyncio
562        events.set_event_loop(None)
563        if cleanup:
564            self.addCleanup(self.close_loop, loop)
565
566    def new_test_loop(self, gen=None):
567        loop = TestLoop(gen)
568        self.set_event_loop(loop)
569        return loop
570
571    def setUp(self):
572        self._thread_cleanup = threading_helper.threading_setup()
573
574    def tearDown(self):
575        events.set_event_loop(None)
576
577        # Detect CPython bug #23353: ensure that yield/yield-from is not used
578        # in an except block of a generator
579        self.assertEqual(sys.exc_info(), (None, None, None))
580
581        self.doCleanups()
582        threading_helper.threading_cleanup(*self._thread_cleanup)
583        support.reap_children()
584
585
@contextlib.contextmanager
def disable_logger():
    """Context manager to disable asyncio logger.

    For example, it can be used to ignore warnings in debug mode.
    The logger's previous level is restored on exit.
    """
    previous_level = logger.level
    silenced_level = logging.CRITICAL + 1
    try:
        logger.setLevel(silenced_level)
        yield
    finally:
        logger.setLevel(previous_level)
598
599
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
                            family=socket.AF_INET):
    """Create a mock of a non-blocking socket.

    The mock is specced on socket.socket, carries the given *proto*,
    *type* and *family*, and its gettimeout() returns 0.0.
    """
    sock = mock.MagicMock(socket.socket)
    attributes = (('proto', proto), ('type', type), ('family', family))
    for attr_name, value in attributes:
        setattr(sock, attr_name, value)
    sock.gettimeout.return_value = 0.0
    return sock
609