1"""Utilities shared by tests.""" 2 3import asyncio 4import collections 5import contextlib 6import io 7import logging 8import os 9import re 10import selectors 11import socket 12import socketserver 13import sys 14import tempfile 15import threading 16import time 17import unittest 18import weakref 19 20from unittest import mock 21 22from http.server import HTTPServer 23from wsgiref.simple_server import WSGIRequestHandler, WSGIServer 24 25try: 26 import ssl 27except ImportError: # pragma: no cover 28 ssl = None 29 30from asyncio import base_events 31from asyncio import events 32from asyncio import format_helpers 33from asyncio import futures 34from asyncio import tasks 35from asyncio.log import logger 36from test import support 37from test.support import threading_helper 38 39 40def data_file(filename): 41 if hasattr(support, 'TEST_HOME_DIR'): 42 fullname = os.path.join(support.TEST_HOME_DIR, filename) 43 if os.path.isfile(fullname): 44 return fullname 45 fullname = os.path.join(os.path.dirname(__file__), '..', filename) 46 if os.path.isfile(fullname): 47 return fullname 48 raise FileNotFoundError(filename) 49 50 51ONLYCERT = data_file('ssl_cert.pem') 52ONLYKEY = data_file('ssl_key.pem') 53SIGNED_CERTFILE = data_file('keycert3.pem') 54SIGNING_CA = data_file('pycacert.pem') 55PEERCERT = { 56 'OCSP': ('http://testca.pythontest.net/testca/ocsp/',), 57 'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',), 58 'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',), 59 'issuer': ((('countryName', 'XY'),), 60 (('organizationName', 'Python Software Foundation CA'),), 61 (('commonName', 'our-ca-server'),)), 62 'notAfter': 'Oct 28 14:23:16 2037 GMT', 63 'notBefore': 'Aug 29 14:23:16 2018 GMT', 64 'serialNumber': 'CB2D80995A69525C', 65 'subject': ((('countryName', 'XY'),), 66 (('localityName', 'Castle Anthrax'),), 67 (('organizationName', 'Python Software Foundation'),), 68 (('commonName', 'localhost'),)), 69 'subjectAltName': (('DNS', 'localhost'),), 70 'version': 3 71} 72 73 74def simple_server_sslcontext(): 75 server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 76 server_context.load_cert_chain(ONLYCERT, ONLYKEY) 77 server_context.check_hostname = False 78 server_context.verify_mode = ssl.CERT_NONE 79 return server_context 80 81 82def simple_client_sslcontext(*, disable_verify=True): 83 client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) 84 client_context.check_hostname = False 85 if disable_verify: 86 client_context.verify_mode = ssl.CERT_NONE 87 return client_context 88 89 90def dummy_ssl_context(): 91 if ssl is None: 92 return None 93 else: 94 return simple_client_sslcontext(disable_verify=True) 95 96 97def run_briefly(loop): 98 async def once(): 99 pass 100 gen = once() 101 t = loop.create_task(gen) 102 # Don't log a warning if the task is not done after run_until_complete(). 103 # It occurs if the loop is stopped or if a task raises a BaseException. 104 t._log_destroy_pending = False 105 try: 106 loop.run_until_complete(t) 107 finally: 108 gen.close() 109 110 111def run_until(loop, pred, timeout=support.SHORT_TIMEOUT): 112 deadline = time.monotonic() + timeout 113 while not pred(): 114 if timeout is not None: 115 timeout = deadline - time.monotonic() 116 if timeout <= 0: 117 raise futures.TimeoutError() 118 loop.run_until_complete(tasks.sleep(0.001)) 119 120 121def run_once(loop): 122 """Legacy API to run once through the event loop. 123 124 This is the recommended pattern for test code. It will poll the 125 selector once and run all callbacks scheduled in response to I/O 126 events. 127 """ 128 loop.call_soon(loop.stop) 129 loop.run_forever() 130 131 132class SilentWSGIRequestHandler(WSGIRequestHandler): 133 134 def get_stderr(self): 135 return io.StringIO() 136 137 def log_message(self, format, *args): 138 pass 139 140 141class SilentWSGIServer(WSGIServer): 142 143 request_timeout = support.LOOPBACK_TIMEOUT 144 145 def get_request(self): 146 request, client_addr = super().get_request() 147 request.settimeout(self.request_timeout) 148 return request, client_addr 149 150 def handle_error(self, request, client_address): 151 pass 152 153 154class SSLWSGIServerMixin: 155 156 def finish_request(self, request, client_address): 157 # The relative location of our test directory (which 158 # contains the ssl key and certificate files) differs 159 # between the stdlib and stand-alone asyncio. 160 # Prefer our own if we can find it. 161 context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 162 context.load_cert_chain(ONLYCERT, ONLYKEY) 163 164 ssock = context.wrap_socket(request, server_side=True) 165 try: 166 self.RequestHandlerClass(ssock, client_address, self) 167 ssock.close() 168 except OSError: 169 # maybe socket has been closed by peer 170 pass 171 172 173class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): 174 pass 175 176 177def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): 178 179 def loop(environ): 180 size = int(environ['CONTENT_LENGTH']) 181 while size: 182 data = environ['wsgi.input'].read(min(size, 0x10000)) 183 yield data 184 size -= len(data) 185 186 def app(environ, start_response): 187 status = '200 OK' 188 headers = [('Content-type', 'text/plain')] 189 start_response(status, headers) 190 if environ['PATH_INFO'] == '/loop': 191 return loop(environ) 192 else: 193 return [b'Test message'] 194 195 # Run the test WSGI server in a separate thread in order not to 196 # interfere with event handling in the main thread 197 server_class = server_ssl_cls if use_ssl else server_cls 198 httpd = server_class(address, SilentWSGIRequestHandler) 199 httpd.set_app(app) 200 httpd.address = httpd.server_address 201 server_thread = threading.Thread( 202 target=lambda: httpd.serve_forever(poll_interval=0.05)) 203 server_thread.start() 204 try: 205 yield httpd 206 finally: 207 httpd.shutdown() 208 httpd.server_close() 209 server_thread.join() 210 211 212if hasattr(socket, 'AF_UNIX'): 213 214 class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): 215 216 def server_bind(self): 217 socketserver.UnixStreamServer.server_bind(self) 218 self.server_name = '127.0.0.1' 219 self.server_port = 80 220 221 222 class UnixWSGIServer(UnixHTTPServer, WSGIServer): 223 224 request_timeout = support.LOOPBACK_TIMEOUT 225 226 def server_bind(self): 227 UnixHTTPServer.server_bind(self) 228 self.setup_environ() 229 230 def get_request(self): 231 request, client_addr = super().get_request() 232 request.settimeout(self.request_timeout) 233 # Code in the stdlib expects that get_request 234 # will return a socket and a tuple (host, port). 235 # However, this isn't true for UNIX sockets, 236 # as the second return value will be a path; 237 # hence we return some fake data sufficient 238 # to get the tests going 239 return request, ('127.0.0.1', '') 240 241 242 class SilentUnixWSGIServer(UnixWSGIServer): 243 244 def handle_error(self, request, client_address): 245 pass 246 247 248 class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): 249 pass 250 251 252 def gen_unix_socket_path(): 253 with tempfile.NamedTemporaryFile() as file: 254 return file.name 255 256 257 @contextlib.contextmanager 258 def unix_socket_path(): 259 path = gen_unix_socket_path() 260 try: 261 yield path 262 finally: 263 try: 264 os.unlink(path) 265 except OSError: 266 pass 267 268 269 @contextlib.contextmanager 270 def run_test_unix_server(*, use_ssl=False): 271 with unix_socket_path() as path: 272 yield from _run_test_server(address=path, use_ssl=use_ssl, 273 server_cls=SilentUnixWSGIServer, 274 server_ssl_cls=UnixSSLWSGIServer) 275 276 277@contextlib.contextmanager 278def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): 279 yield from _run_test_server(address=(host, port), use_ssl=use_ssl, 280 server_cls=SilentWSGIServer, 281 server_ssl_cls=SSLWSGIServer) 282 283 284def echo_datagrams(sock): 285 while True: 286 data, addr = sock.recvfrom(4096) 287 if data == b'STOP': 288 sock.close() 289 break 290 else: 291 sock.sendto(data, addr) 292 293 294@contextlib.contextmanager 295def run_udp_echo_server(*, host='127.0.0.1', port=0): 296 addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM) 297 family, type, proto, _, sockaddr = addr_info[0] 298 sock = socket.socket(family, type, proto) 299 sock.bind((host, port)) 300 thread = threading.Thread(target=lambda: echo_datagrams(sock)) 301 thread.start() 302 try: 303 yield sock.getsockname() 304 finally: 305 sock.sendto(b'STOP', sock.getsockname()) 306 thread.join() 307 308 309def make_test_protocol(base): 310 dct = {} 311 for name in dir(base): 312 if name.startswith('__') and name.endswith('__'): 313 # skip magic names 314 continue 315 dct[name] = MockCallback(return_value=None) 316 return type('TestProtocol', (base,) + base.__bases__, dct)() 317 318 319class TestSelector(selectors.BaseSelector): 320 321 def __init__(self): 322 self.keys = {} 323 324 def register(self, fileobj, events, data=None): 325 key = selectors.SelectorKey(fileobj, 0, events, data) 326 self.keys[fileobj] = key 327 return key 328 329 def unregister(self, fileobj): 330 return self.keys.pop(fileobj) 331 332 def select(self, timeout): 333 return [] 334 335 def get_map(self): 336 return self.keys 337 338 339class TestLoop(base_events.BaseEventLoop): 340 """Loop for unittests. 341 342 It manages self time directly. 343 If something scheduled to be executed later then 344 on next loop iteration after all ready handlers done 345 generator passed to __init__ is calling. 346 347 Generator should be like this: 348 349 def gen(): 350 ... 351 when = yield ... 352 ... = yield time_advance 353 354 Value returned by yield is absolute time of next scheduled handler. 355 Value passed to yield is time advance to move loop's time forward. 356 """ 357 358 def __init__(self, gen=None): 359 super().__init__() 360 361 if gen is None: 362 def gen(): 363 yield 364 self._check_on_close = False 365 else: 366 self._check_on_close = True 367 368 self._gen = gen() 369 next(self._gen) 370 self._time = 0 371 self._clock_resolution = 1e-9 372 self._timers = [] 373 self._selector = TestSelector() 374 375 self.readers = {} 376 self.writers = {} 377 self.reset_counters() 378 379 self._transports = weakref.WeakValueDictionary() 380 381 def time(self): 382 return self._time 383 384 def advance_time(self, advance): 385 """Move test time forward.""" 386 if advance: 387 self._time += advance 388 389 def close(self): 390 super().close() 391 if self._check_on_close: 392 try: 393 self._gen.send(0) 394 except StopIteration: 395 pass 396 else: # pragma: no cover 397 raise AssertionError("Time generator is not finished") 398 399 def _add_reader(self, fd, callback, *args): 400 self.readers[fd] = events.Handle(callback, args, self, None) 401 402 def _remove_reader(self, fd): 403 self.remove_reader_count[fd] += 1 404 if fd in self.readers: 405 del self.readers[fd] 406 return True 407 else: 408 return False 409 410 def assert_reader(self, fd, callback, *args): 411 if fd not in self.readers: 412 raise AssertionError(f'fd {fd} is not registered') 413 handle = self.readers[fd] 414 if handle._callback != callback: 415 raise AssertionError( 416 f'unexpected callback: {handle._callback} != {callback}') 417 if handle._args != args: 418 raise AssertionError( 419 f'unexpected callback args: {handle._args} != {args}') 420 421 def assert_no_reader(self, fd): 422 if fd in self.readers: 423 raise AssertionError(f'fd {fd} is registered') 424 425 def _add_writer(self, fd, callback, *args): 426 self.writers[fd] = events.Handle(callback, args, self, None) 427 428 def _remove_writer(self, fd): 429 self.remove_writer_count[fd] += 1 430 if fd in self.writers: 431 del self.writers[fd] 432 return True 433 else: 434 return False 435 436 def assert_writer(self, fd, callback, *args): 437 if fd not in self.writers: 438 raise AssertionError(f'fd {fd} is not registered') 439 handle = self.writers[fd] 440 if handle._callback != callback: 441 raise AssertionError(f'{handle._callback!r} != {callback!r}') 442 if handle._args != args: 443 raise AssertionError(f'{handle._args!r} != {args!r}') 444 445 def _ensure_fd_no_transport(self, fd): 446 if not isinstance(fd, int): 447 try: 448 fd = int(fd.fileno()) 449 except (AttributeError, TypeError, ValueError): 450 # This code matches selectors._fileobj_to_fd function. 451 raise ValueError("Invalid file object: " 452 "{!r}".format(fd)) from None 453 try: 454 transport = self._transports[fd] 455 except KeyError: 456 pass 457 else: 458 raise RuntimeError( 459 'File descriptor {!r} is used by transport {!r}'.format( 460 fd, transport)) 461 462 def add_reader(self, fd, callback, *args): 463 """Add a reader callback.""" 464 self._ensure_fd_no_transport(fd) 465 return self._add_reader(fd, callback, *args) 466 467 def remove_reader(self, fd): 468 """Remove a reader callback.""" 469 self._ensure_fd_no_transport(fd) 470 return self._remove_reader(fd) 471 472 def add_writer(self, fd, callback, *args): 473 """Add a writer callback..""" 474 self._ensure_fd_no_transport(fd) 475 return self._add_writer(fd, callback, *args) 476 477 def remove_writer(self, fd): 478 """Remove a writer callback.""" 479 self._ensure_fd_no_transport(fd) 480 return self._remove_writer(fd) 481 482 def reset_counters(self): 483 self.remove_reader_count = collections.defaultdict(int) 484 self.remove_writer_count = collections.defaultdict(int) 485 486 def _run_once(self): 487 super()._run_once() 488 for when in self._timers: 489 advance = self._gen.send(when) 490 self.advance_time(advance) 491 self._timers = [] 492 493 def call_at(self, when, callback, *args, context=None): 494 self._timers.append(when) 495 return super().call_at(when, callback, *args, context=context) 496 497 def _process_events(self, event_list): 498 return 499 500 def _write_to_self(self): 501 pass 502 503 504def MockCallback(**kwargs): 505 return mock.Mock(spec=['__call__'], **kwargs) 506 507 508class MockPattern(str): 509 """A regex based str with a fuzzy __eq__. 510 511 Use this helper with 'mock.assert_called_with', or anywhere 512 where a regex comparison between strings is needed. 513 514 For instance: 515 mock_call.assert_called_with(MockPattern('spam.*ham')) 516 """ 517 def __eq__(self, other): 518 return bool(re.search(str(self), other, re.S)) 519 520 521class MockInstanceOf: 522 def __init__(self, type): 523 self._type = type 524 525 def __eq__(self, other): 526 return isinstance(other, self._type) 527 528 529def get_function_source(func): 530 source = format_helpers._get_function_source(func) 531 if source is None: 532 raise ValueError("unable to get the source of %r" % (func,)) 533 return source 534 535 536class TestCase(unittest.TestCase): 537 @staticmethod 538 def close_loop(loop): 539 if loop._default_executor is not None: 540 if not loop.is_closed(): 541 loop.run_until_complete(loop.shutdown_default_executor()) 542 else: 543 loop._default_executor.shutdown(wait=True) 544 loop.close() 545 policy = support.maybe_get_event_loop_policy() 546 if policy is not None: 547 try: 548 watcher = policy.get_child_watcher() 549 except NotImplementedError: 550 # watcher is not implemented by EventLoopPolicy, e.g. Windows 551 pass 552 else: 553 if isinstance(watcher, asyncio.ThreadedChildWatcher): 554 threads = list(watcher._threads.values()) 555 for thread in threads: 556 thread.join() 557 558 def set_event_loop(self, loop, *, cleanup=True): 559 if loop is None: 560 raise AssertionError('loop is None') 561 # ensure that the event loop is passed explicitly in asyncio 562 events.set_event_loop(None) 563 if cleanup: 564 self.addCleanup(self.close_loop, loop) 565 566 def new_test_loop(self, gen=None): 567 loop = TestLoop(gen) 568 self.set_event_loop(loop) 569 return loop 570 571 def setUp(self): 572 self._thread_cleanup = threading_helper.threading_setup() 573 574 def tearDown(self): 575 events.set_event_loop(None) 576 577 # Detect CPython bug #23353: ensure that yield/yield-from is not used 578 # in an except block of a generator 579 self.assertEqual(sys.exc_info(), (None, None, None)) 580 581 self.doCleanups() 582 threading_helper.threading_cleanup(*self._thread_cleanup) 583 support.reap_children() 584 585 586@contextlib.contextmanager 587def disable_logger(): 588 """Context manager to disable asyncio logger. 589 590 For example, it can be used to ignore warnings in debug mode. 591 """ 592 old_level = logger.level 593 try: 594 logger.setLevel(logging.CRITICAL+1) 595 yield 596 finally: 597 logger.setLevel(old_level) 598 599 600def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM, 601 family=socket.AF_INET): 602 """Create a mock of a non-blocking socket.""" 603 sock = mock.MagicMock(socket.socket) 604 sock.proto = proto 605 sock.type = type 606 sock.family = family 607 sock.gettimeout.return_value = 0.0 608 return sock 609