"""Test case implementation"""

import sys
import functools
import difflib
import pprint
import re
import warnings
import collections
import contextlib
import traceback
import types

from . import result
from .util import (strclass, safe_repr, _count_diff_all_purpose,
                   _count_diff_hashable, _common_shorten_repr)

__unittest = True

_subtest_msg_sentinel = object()

DIFF_OMITTED = ('\nDiff is %s characters long. '
                 'Set self.maxDiff to None to see it.')

class SkipTest(Exception):
    """
    Raise this exception in a test to skip it.

    Usually you can use TestCase.skipTest() or one of the skipping decorators
    instead of raising this directly.
    """

class _ShouldStop(Exception):
    """
    The test should stop.

    Internal signal: _Outcome.testPartExecutor catches it and reports
    nothing, so the current test part simply ends.
    """

class _UnexpectedSuccess(Exception):
    """
    The test was supposed to fail, but it didn't!
    """


class _Outcome(object):
    """Mutable record of how the parts of a single test run ended.

    Attributes:
        expecting_failure: when true, an exception caught by
            testPartExecutor is stored in expectedFailure instead of
            being reported.
        result: the result object outcomes are reported to (may be None).
        result_supports_subtests: whether ``result`` has an addSubTest
            method.
        success: becomes False once any part fails, errors or is skipped.
        expectedFailure: the sys.exc_info() triple captured while
            expecting_failure was set, otherwise None.
    """

    def __init__(self, result=None):
        self.expecting_failure = False
        self.result = result
        self.result_supports_subtests = hasattr(result, "addSubTest")
        self.success = True
        self.expectedFailure = None

    @contextlib.contextmanager
    def testPartExecutor(self, test_case, subTest=False):
        """Run the enclosed block as one part of a test and record its outcome.

        * KeyboardInterrupt propagates unchanged.
        * SkipTest marks the part unsuccessful and reports a skip.
        * _ShouldStop is swallowed without reporting anything.
        * Any other exception is stored in expectedFailure when
          expecting_failure is set; otherwise the part is marked
          unsuccessful and reported through result.addSubTest (when
          *subTest* is true) or _addError.
        * If the block completes normally inside a subtest and the part is
          still successful, result.addSubTest(..., None) reports success.

        On exit self.success is ANDed with its value on entry, so a failure
        recorded by an earlier part is never lost.
        """
        old_success = self.success
        self.success = True
        try:
            yield
        except KeyboardInterrupt:
            raise
        except SkipTest as e:
            self.success = False
            _addSkip(self.result, test_case, str(e))
        except _ShouldStop:
            pass
        except:
            # Deliberately bare: anything else a test part raises must be
            # recorded against the test rather than propagated.
            exc_info = sys.exc_info()
            if self.expecting_failure:
                self.expectedFailure = exc_info
            else:
                self.success = False
                if subTest:
                    self.result.addSubTest(test_case.test_case, test_case, exc_info)
                else:
                    _addError(self.result, test_case, exc_info)
            # explicitly break a reference cycle:
            # exc_info -> frame -> exc_info
            exc_info = None
        else:
            if subTest and self.success:
                self.result.addSubTest(test_case.test_case, test_case, None)
        finally:
            self.success = self.success and old_success


def _addSkip(result, test_case, reason):
    """Report *test_case* as skipped for *reason* to *result*.

    A result without an addSkip method triggers a RuntimeWarning and the
    test is reported through addSuccess instead.
    """
    addSkip = getattr(result, 'addSkip', None)
    if addSkip is not None:
        addSkip(test_case, reason)
    else:
        warnings.warn("TestResult has no addSkip method, skips not reported",
                      RuntimeWarning, 2)
        result.addSuccess(test_case)

def _addError(result, test, exc_info):
    """Report *exc_info* to *result* as a failure or as an error.

    Exceptions whose type subclasses test.failureException are failures;
    anything else is an error.  Does nothing if *result* or *exc_info* is
    None.
    """
    if result is not None and exc_info is not None:
        if issubclass(exc_info[0], test.failureException):
            result.addFailure(test, exc_info)
        else:
            result.addError(test, exc_info)

def _id(obj):
    """Identity decorator: return *obj* unchanged."""
    return obj


def _enter_context(cm, addcleanup):
    """Enter context manager *cm* and register its __exit__ via *addcleanup*.

    Returns the value produced by ``cm.__enter__()``.

    Raises:
        TypeError: if type(cm) lacks __enter__ or __exit__.
    """
    # We look up the special methods on the type to match the with
    # statement.
    cls = type(cm)
    try:
        enter = cls.__enter__
        exit = cls.__exit__
    except AttributeError:
        raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
                        f"not support the context manager protocol") from None
    result = enter(cm)
    # The cleanup calls __exit__ as though the block ended without an error.
    addcleanup(exit, cm, None, None, None)
    return result


_module_cleanups = []
def addModuleCleanup(function, /, *args, **kwargs):
    """Same as addCleanup, except the cleanup items are called even if
    setUpModule fails (unlike tearDownModule)."""
    _module_cleanups.append((function, args, kwargs))

def enterModuleContext(cm):
    """Same as enterContext, but module-wide."""
    return _enter_context(cm, addModuleCleanup)


def doModuleCleanups():
    """Execute all module cleanup functions. Normally called for you after
    tearDownModule.

    Cleanups run in LIFO order and every one is attempted.  If any raised
    an Exception, the first one raised is re-raised after all have run.
    """
    exceptions = []
    while _module_cleanups:
        function, args, kwargs = _module_cleanups.pop()
        try:
            function(*args, **kwargs)
        except Exception as exc:
            exceptions.append(exc)
    if exceptions:
        # Swallows all but first exception. If a multi-exception handler
        # gets written we should use that here instead.
        raise exceptions[0]


def skip(reason):
    """
    Unconditionally skip a test.

    Usable as ``@skip("why")`` or, with an empty reason, as a bare ``@skip``
    applied directly to a function.  Classes are only marked; any other
    item is replaced by a wrapper that raises SkipTest when called.
    """
    def decorator(test_item):
        if not isinstance(test_item, type):
            @functools.wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)
            test_item = skip_wrapper

        test_item.__unittest_skip__ = True
        test_item.__unittest_skip_why__ = reason
        return test_item
    if isinstance(reason, types.FunctionType):
        # Bare @skip: the "reason" argument is really the decorated function.
        test_item = reason
        reason = ''
        return decorator(test_item)
    return decorator

def skipIf(condition, reason):
    """
    Skip a test if the condition is true.
    """
    if condition:
        return skip(reason)
    return _id

def skipUnless(condition, reason):
    """
    Skip a test unless the condition is true.
    """
    if not condition:
        return skip(reason)
    return _id

def expectedFailure(test_item):
    """Mark *test_item* as expected to fail.

    When run, a marked test that raises is reported as an expected failure,
    and one that passes is reported as an unexpected success.
    """
    test_item.__unittest_expecting_failure__ = True
    return test_item

def _is_subtype(expected, basetype):
    """Return True if *expected* is a subclass of *basetype*, or a (possibly
    nested) tuple whose members all are; an empty tuple yields True."""
    if isinstance(expected, tuple):
        return all(_is_subtype(e, basetype) for e in expected)
    return isinstance(expected, type) and issubclass(expected, basetype)

class _BaseTestCaseContext:
    """Base for assertion context managers bound to a TestCase instance."""

    def __init__(self, test_case):
        self.test_case = test_case

    def _raiseFailure(self, standardMsg):
        # Expects the subclass to have set self.msg; combines it with
        # standardMsg per the test case's longMessage setting.
        msg = self.test_case._formatMessage(self.msg, standardMsg)
        raise self.test_case.failureException(msg)

class _AssertRaisesBaseContext(_BaseTestCaseContext):
    """Shared machinery for the assertRaises*/assertWarns* helpers.

    Subclasses define _base_type (the class *expected* must derive from)
    and _base_type_str (its description in the TypeError message).
    """

    def __init__(self, expected, test_case, expected_regex=None):
        _BaseTestCaseContext.__init__(self, test_case)
        self.expected = expected
        self.test_case = test_case
        if expected_regex is not None:
            expected_regex = re.compile(expected_regex)
        self.expected_regex = expected_regex
        self.obj_name = None
        self.msg = None

    def handle(self, name, args, kwargs):
        """
        If args is empty, assertRaises/Warns is being used as a
        context manager, so check for a 'msg' kwarg and return self.
        If args is not empty, call a callable passing positional and keyword
        arguments.

        Raises TypeError if *expected* is not of the subclass's base type,
        or if context-manager use passes keyword arguments other than 'msg'.
        """
        try:
            if not _is_subtype(self.expected, self._base_type):
                raise TypeError('%s() arg 1 must be %s' %
                                (name, self._base_type_str))
            if not args:
                self.msg = kwargs.pop('msg', None)
                if kwargs:
                    raise TypeError('%r is an invalid keyword argument for '
                                    'this function' % (next(iter(kwargs)),))
                return self

            callable_obj, *args = args
            try:
                self.obj_name = callable_obj.__name__
            except AttributeError:
                self.obj_name = str(callable_obj)
            with self:
                callable_obj(*args, **kwargs)
        finally:
            # bpo-23890: manually break a reference cycle
            self = None


class _AssertRaisesContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertRaises* methods."""

    _base_type = BaseException
    _base_type_str = 'an exception type or tuple of exception types'

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # Fails when nothing was raised or the message does not match
        # expected_regex; returns True to suppress an expected exception and
        # False to let any other exception propagate.
        if exc_type is None:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                exc_name = str(self.expected)
            if self.obj_name:
                self._raiseFailure("{} not raised by {}".format(exc_name,
                                                                self.obj_name))
            else:
                self._raiseFailure("{} not raised".format(exc_name))
        else:
            traceback.clear_frames(tb)
        if not issubclass(exc_type, self.expected):
            # let unexpected exceptions pass through
            return False
        # store exception, without traceback, for later retrieval
        self.exception = exc_value.with_traceback(None)
        if self.expected_regex is None:
            return True

        expected_regex = self.expected_regex
        if not expected_regex.search(str(exc_value)):
            self._raiseFailure('"{}" does not match "{}"'.format(
                     expected_regex.pattern, str(exc_value)))
        return True

    __class_getitem__ = classmethod(types.GenericAlias)


class _AssertWarnsContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertWarns* methods."""

    _base_type = Warning
    _base_type_str = 'a warning type or tuple of warning types'

    def __enter__(self):
        # The __warningregistry__'s need to be in a pristine state for tests
        # to work properly.
        for v in list(sys.modules.values()):
            if getattr(v, '__warningregistry__', None):
                v.__warningregistry__ = {}
        self.warnings_manager = warnings.catch_warnings(record=True)
        self.warnings = self.warnings_manager.__enter__()
        # "always" so repeated warnings are recorded, not deduplicated.
        warnings.simplefilter("always", self.expected)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        # On success, stores the first matching warning in self.warning with
        # its filename/lineno; otherwise raises the test's failureException.
        self.warnings_manager.__exit__(exc_type, exc_value, tb)
        if exc_type is not None:
            # let unexpected exceptions pass through
            return
        try:
            exc_name = self.expected.__name__
        except AttributeError:
            exc_name = str(self.expected)
        first_matching = None
        for m in self.warnings:
            w = m.message
            if not isinstance(w, self.expected):
                continue
            if first_matching is None:
                first_matching = w
            if (self.expected_regex is not None and
                not self.expected_regex.search(str(w))):
                continue
            # store warning for later retrieval
            self.warning = w
            self.filename = m.filename
            self.lineno = m.lineno
            return
        # Now we simply try to choose a helpful failure message
        if first_matching is not None:
            self._raiseFailure('"{}" does not match "{}"'.format(
                     self.expected_regex.pattern, str(first_matching)))
        if self.obj_name:
            self._raiseFailure("{} not triggered by {}".format(exc_name,
                                                               self.obj_name))
        else:
            self._raiseFailure("{} not triggered".format(exc_name))


class _OrderedChainMap(collections.ChainMap):
    """ChainMap whose iteration yields each key once, in the order keys are
    first met when walking self.maps from first to last."""

    def __iter__(self):
        seen = set()
        for mapping in self.maps:
            for k in mapping:
                if k not in seen:
                    seen.add(k)
                    yield k


class TestCase(object):
    """A class whose instances are single test cases.

    By default, the test code itself should be placed in a method named
    'runTest'.

    If the fixture may be used for many test cases, create as
    many test methods as are needed. When instantiating such a TestCase
    subclass, specify in the constructor arguments the name of the test method
    that the instance is to execute.

    Test authors should subclass TestCase for their own tests. Construction
    and deconstruction of the test's environment ('fixture') can be
    implemented by overriding the 'setUp' and 'tearDown' methods respectively.

    If it is necessary to override the __init__ method, the base class
    __init__ method must always be called. It is important that subclasses
    should not change the signature of their __init__ method, since instances
    of the classes are instantiated automatically by parts of the framework
    in order to be run.

    When subclassing TestCase, you can set these attributes:
    * failureException: determines which exception will be raised when
        the instance's assertion methods fail; test methods raising this
        exception will be deemed to have 'failed' rather than 'errored'.
    * longMessage: determines whether long messages (including repr of
        objects used in assert methods) will be printed on failure in *addition*
        to any explicit message passed.
    * maxDiff: sets the maximum length of a diff in failure messages
        by assert methods using difflib. It is looked up as an instance
        attribute so can be configured by individual tests if required.
    """

    failureException = AssertionError

    longMessage = True

    maxDiff = 80*8

    # If a string is longer than _diffThreshold, use normal comparison instead
    # of difflib. See #11763.
385 _diffThreshold = 2**16 386 387 def __init_subclass__(cls, *args, **kwargs): 388 # Attribute used by TestSuite for classSetUp 389 cls._classSetupFailed = False 390 cls._class_cleanups = [] 391 super().__init_subclass__(*args, **kwargs) 392 393 def __init__(self, methodName='runTest'): 394 """Create an instance of the class that will use the named test 395 method when executed. Raises a ValueError if the instance does 396 not have a method with the specified name. 397 """ 398 self._testMethodName = methodName 399 self._outcome = None 400 self._testMethodDoc = 'No test' 401 try: 402 testMethod = getattr(self, methodName) 403 except AttributeError: 404 if methodName != 'runTest': 405 # we allow instantiation with no explicit method name 406 # but not an *incorrect* or missing method name 407 raise ValueError("no such test method in %s: %s" % 408 (self.__class__, methodName)) 409 else: 410 self._testMethodDoc = testMethod.__doc__ 411 self._cleanups = [] 412 self._subtest = None 413 414 # Map types to custom assertEqual functions that will compare 415 # instances of said type in more detail to generate a more useful 416 # error message. 417 self._type_equality_funcs = {} 418 self.addTypeEqualityFunc(dict, 'assertDictEqual') 419 self.addTypeEqualityFunc(list, 'assertListEqual') 420 self.addTypeEqualityFunc(tuple, 'assertTupleEqual') 421 self.addTypeEqualityFunc(set, 'assertSetEqual') 422 self.addTypeEqualityFunc(frozenset, 'assertSetEqual') 423 self.addTypeEqualityFunc(str, 'assertMultiLineEqual') 424 425 def addTypeEqualityFunc(self, typeobj, function): 426 """Add a type specific assertEqual style function to compare a type. 427 428 This method is for use by TestCase subclasses that need to register 429 their own type equality functions to provide nicer error messages. 430 431 Args: 432 typeobj: The data type to call this function on when both values 433 are of the same type in assertEqual(). 
434 function: The callable taking two arguments and an optional 435 msg= argument that raises self.failureException with a 436 useful error message when the two arguments are not equal. 437 """ 438 self._type_equality_funcs[typeobj] = function 439 440 def addCleanup(self, function, /, *args, **kwargs): 441 """Add a function, with arguments, to be called when the test is 442 completed. Functions added are called on a LIFO basis and are 443 called after tearDown on test failure or success. 444 445 Cleanup items are called even if setUp fails (unlike tearDown).""" 446 self._cleanups.append((function, args, kwargs)) 447 448 def enterContext(self, cm): 449 """Enters the supplied context manager. 450 451 If successful, also adds its __exit__ method as a cleanup 452 function and returns the result of the __enter__ method. 453 """ 454 return _enter_context(cm, self.addCleanup) 455 456 @classmethod 457 def addClassCleanup(cls, function, /, *args, **kwargs): 458 """Same as addCleanup, except the cleanup items are called even if 459 setUpClass fails (unlike tearDownClass).""" 460 cls._class_cleanups.append((function, args, kwargs)) 461 462 @classmethod 463 def enterClassContext(cls, cm): 464 """Same as enterContext, but class-wide.""" 465 return _enter_context(cm, cls.addClassCleanup) 466 467 def setUp(self): 468 "Hook method for setting up the test fixture before exercising it." 469 pass 470 471 def tearDown(self): 472 "Hook method for deconstructing the test fixture after testing it." 473 pass 474 475 @classmethod 476 def setUpClass(cls): 477 "Hook method for setting up class fixture before running tests in the class." 478 479 @classmethod 480 def tearDownClass(cls): 481 "Hook method for deconstructing the class fixture after running all tests in the class." 
482 483 def countTestCases(self): 484 return 1 485 486 def defaultTestResult(self): 487 return result.TestResult() 488 489 def shortDescription(self): 490 """Returns a one-line description of the test, or None if no 491 description has been provided. 492 493 The default implementation of this method returns the first line of 494 the specified test method's docstring. 495 """ 496 doc = self._testMethodDoc 497 return doc.strip().split("\n")[0].strip() if doc else None 498 499 500 def id(self): 501 return "%s.%s" % (strclass(self.__class__), self._testMethodName) 502 503 def __eq__(self, other): 504 if type(self) is not type(other): 505 return NotImplemented 506 507 return self._testMethodName == other._testMethodName 508 509 def __hash__(self): 510 return hash((type(self), self._testMethodName)) 511 512 def __str__(self): 513 return "%s (%s.%s)" % (self._testMethodName, strclass(self.__class__), self._testMethodName) 514 515 def __repr__(self): 516 return "<%s testMethod=%s>" % \ 517 (strclass(self.__class__), self._testMethodName) 518 519 @contextlib.contextmanager 520 def subTest(self, msg=_subtest_msg_sentinel, **params): 521 """Return a context manager that will return the enclosed block 522 of code in a subtest identified by the optional message and 523 keyword parameters. A failure in the subtest marks the test 524 case as failed but resumes execution at the end of the enclosed 525 block, allowing further test code to be executed. 
526 """ 527 if self._outcome is None or not self._outcome.result_supports_subtests: 528 yield 529 return 530 parent = self._subtest 531 if parent is None: 532 params_map = _OrderedChainMap(params) 533 else: 534 params_map = parent.params.new_child(params) 535 self._subtest = _SubTest(self, msg, params_map) 536 try: 537 with self._outcome.testPartExecutor(self._subtest, subTest=True): 538 yield 539 if not self._outcome.success: 540 result = self._outcome.result 541 if result is not None and result.failfast: 542 raise _ShouldStop 543 elif self._outcome.expectedFailure: 544 # If the test is expecting a failure, we really want to 545 # stop now and register the expected failure. 546 raise _ShouldStop 547 finally: 548 self._subtest = parent 549 550 def _addExpectedFailure(self, result, exc_info): 551 try: 552 addExpectedFailure = result.addExpectedFailure 553 except AttributeError: 554 warnings.warn("TestResult has no addExpectedFailure method, reporting as passes", 555 RuntimeWarning) 556 result.addSuccess(self) 557 else: 558 addExpectedFailure(self, exc_info) 559 560 def _addUnexpectedSuccess(self, result): 561 try: 562 addUnexpectedSuccess = result.addUnexpectedSuccess 563 except AttributeError: 564 warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure", 565 RuntimeWarning) 566 # We need to pass an actual exception and traceback to addFailure, 567 # otherwise the legacy result can choke. 
568 try: 569 raise _UnexpectedSuccess from None 570 except _UnexpectedSuccess: 571 result.addFailure(self, sys.exc_info()) 572 else: 573 addUnexpectedSuccess(self) 574 575 def _callSetUp(self): 576 self.setUp() 577 578 def _callTestMethod(self, method): 579 if method() is not None: 580 warnings.warn(f'It is deprecated to return a value that is not None from a ' 581 f'test case ({method})', DeprecationWarning, stacklevel=3) 582 583 def _callTearDown(self): 584 self.tearDown() 585 586 def _callCleanup(self, function, /, *args, **kwargs): 587 function(*args, **kwargs) 588 589 def run(self, result=None): 590 if result is None: 591 result = self.defaultTestResult() 592 startTestRun = getattr(result, 'startTestRun', None) 593 stopTestRun = getattr(result, 'stopTestRun', None) 594 if startTestRun is not None: 595 startTestRun() 596 else: 597 stopTestRun = None 598 599 result.startTest(self) 600 try: 601 testMethod = getattr(self, self._testMethodName) 602 if (getattr(self.__class__, "__unittest_skip__", False) or 603 getattr(testMethod, "__unittest_skip__", False)): 604 # If the class or method was skipped. 
605 skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') 606 or getattr(testMethod, '__unittest_skip_why__', '')) 607 _addSkip(result, self, skip_why) 608 return result 609 610 expecting_failure = ( 611 getattr(self, "__unittest_expecting_failure__", False) or 612 getattr(testMethod, "__unittest_expecting_failure__", False) 613 ) 614 outcome = _Outcome(result) 615 try: 616 self._outcome = outcome 617 618 with outcome.testPartExecutor(self): 619 self._callSetUp() 620 if outcome.success: 621 outcome.expecting_failure = expecting_failure 622 with outcome.testPartExecutor(self): 623 self._callTestMethod(testMethod) 624 outcome.expecting_failure = False 625 with outcome.testPartExecutor(self): 626 self._callTearDown() 627 self.doCleanups() 628 629 if outcome.success: 630 if expecting_failure: 631 if outcome.expectedFailure: 632 self._addExpectedFailure(result, outcome.expectedFailure) 633 else: 634 self._addUnexpectedSuccess(result) 635 else: 636 result.addSuccess(self) 637 return result 638 finally: 639 # explicitly break reference cycle: 640 # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure 641 outcome.expectedFailure = None 642 outcome = None 643 644 # clear the outcome, no more needed 645 self._outcome = None 646 647 finally: 648 result.stopTest(self) 649 if stopTestRun is not None: 650 stopTestRun() 651 652 def doCleanups(self): 653 """Execute all cleanup functions. Normally called for you after 654 tearDown.""" 655 outcome = self._outcome or _Outcome() 656 while self._cleanups: 657 function, args, kwargs = self._cleanups.pop() 658 with outcome.testPartExecutor(self): 659 self._callCleanup(function, *args, **kwargs) 660 661 # return this for backwards compatibility 662 # even though we no longer use it internally 663 return outcome.success 664 665 @classmethod 666 def doClassCleanups(cls): 667 """Execute all class cleanup functions. 
Normally called for you after 668 tearDownClass.""" 669 cls.tearDown_exceptions = [] 670 while cls._class_cleanups: 671 function, args, kwargs = cls._class_cleanups.pop() 672 try: 673 function(*args, **kwargs) 674 except Exception: 675 cls.tearDown_exceptions.append(sys.exc_info()) 676 677 def __call__(self, *args, **kwds): 678 return self.run(*args, **kwds) 679 680 def debug(self): 681 """Run the test without collecting errors in a TestResult""" 682 testMethod = getattr(self, self._testMethodName) 683 if (getattr(self.__class__, "__unittest_skip__", False) or 684 getattr(testMethod, "__unittest_skip__", False)): 685 # If the class or method was skipped. 686 skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') 687 or getattr(testMethod, '__unittest_skip_why__', '')) 688 raise SkipTest(skip_why) 689 690 self._callSetUp() 691 self._callTestMethod(testMethod) 692 self._callTearDown() 693 while self._cleanups: 694 function, args, kwargs = self._cleanups.pop() 695 self._callCleanup(function, *args, **kwargs) 696 697 def skipTest(self, reason): 698 """Skip this test.""" 699 raise SkipTest(reason) 700 701 def fail(self, msg=None): 702 """Fail immediately, with the given message.""" 703 raise self.failureException(msg) 704 705 def assertFalse(self, expr, msg=None): 706 """Check that the expression is false.""" 707 if expr: 708 msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr)) 709 raise self.failureException(msg) 710 711 def assertTrue(self, expr, msg=None): 712 """Check that the expression is true.""" 713 if not expr: 714 msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr)) 715 raise self.failureException(msg) 716 717 def _formatMessage(self, msg, standardMsg): 718 """Honour the longMessage attribute when generating failure messages. 
719 If longMessage is False this means: 720 * Use only an explicit message if it is provided 721 * Otherwise use the standard message for the assert 722 723 If longMessage is True: 724 * Use the standard message 725 * If an explicit message is provided, plus ' : ' and the explicit message 726 """ 727 if not self.longMessage: 728 return msg or standardMsg 729 if msg is None: 730 return standardMsg 731 try: 732 # don't switch to '{}' formatting in Python 2.X 733 # it changes the way unicode input is handled 734 return '%s : %s' % (standardMsg, msg) 735 except UnicodeDecodeError: 736 return '%s : %s' % (safe_repr(standardMsg), safe_repr(msg)) 737 738 def assertRaises(self, expected_exception, *args, **kwargs): 739 """Fail unless an exception of class expected_exception is raised 740 by the callable when invoked with specified positional and 741 keyword arguments. If a different type of exception is 742 raised, it will not be caught, and the test case will be 743 deemed to have suffered an error, exactly as for an 744 unexpected exception. 745 746 If called with the callable and arguments omitted, will return a 747 context object used like this:: 748 749 with self.assertRaises(SomeException): 750 do_something() 751 752 An optional keyword argument 'msg' can be provided when assertRaises 753 is used as a context object. 754 755 The context manager keeps a reference to the exception as 756 the 'exception' attribute. 
This allows you to inspect the 757 exception after the assertion:: 758 759 with self.assertRaises(SomeException) as cm: 760 do_something() 761 the_exception = cm.exception 762 self.assertEqual(the_exception.error_code, 3) 763 """ 764 context = _AssertRaisesContext(expected_exception, self) 765 try: 766 return context.handle('assertRaises', args, kwargs) 767 finally: 768 # bpo-23890: manually break a reference cycle 769 context = None 770 771 def assertWarns(self, expected_warning, *args, **kwargs): 772 """Fail unless a warning of class warnClass is triggered 773 by the callable when invoked with specified positional and 774 keyword arguments. If a different type of warning is 775 triggered, it will not be handled: depending on the other 776 warning filtering rules in effect, it might be silenced, printed 777 out, or raised as an exception. 778 779 If called with the callable and arguments omitted, will return a 780 context object used like this:: 781 782 with self.assertWarns(SomeWarning): 783 do_something() 784 785 An optional keyword argument 'msg' can be provided when assertWarns 786 is used as a context object. 787 788 The context manager keeps a reference to the first matching 789 warning as the 'warning' attribute; similarly, the 'filename' 790 and 'lineno' attributes give you information about the line 791 of Python code from which the warning was triggered. 792 This allows you to inspect the warning after the assertion:: 793 794 with self.assertWarns(SomeWarning) as cm: 795 do_something() 796 the_warning = cm.warning 797 self.assertEqual(the_warning.some_attribute, 147) 798 """ 799 context = _AssertWarnsContext(expected_warning, self) 800 return context.handle('assertWarns', args, kwargs) 801 802 def assertLogs(self, logger=None, level=None): 803 """Fail unless a log message of level *level* or higher is emitted 804 on *logger_name* or its children. If omitted, *level* defaults to 805 INFO and *logger* defaults to the root logger. 
806 807 This method must be used as a context manager, and will yield 808 a recording object with two attributes: `output` and `records`. 809 At the end of the context manager, the `output` attribute will 810 be a list of the matching formatted log messages and the 811 `records` attribute will be a list of the corresponding LogRecord 812 objects. 813 814 Example:: 815 816 with self.assertLogs('foo', level='INFO') as cm: 817 logging.getLogger('foo').info('first message') 818 logging.getLogger('foo.bar').error('second message') 819 self.assertEqual(cm.output, ['INFO:foo:first message', 820 'ERROR:foo.bar:second message']) 821 """ 822 # Lazy import to avoid importing logging if it is not needed. 823 from ._log import _AssertLogsContext 824 return _AssertLogsContext(self, logger, level, no_logs=False) 825 826 def assertNoLogs(self, logger=None, level=None): 827 """ Fail unless no log messages of level *level* or higher are emitted 828 on *logger_name* or its children. 829 830 This method must be used as a context manager. 831 """ 832 from ._log import _AssertLogsContext 833 return _AssertLogsContext(self, logger, level, no_logs=True) 834 835 def _getAssertEqualityFunc(self, first, second): 836 """Get a detailed comparison function for the types of the two args. 837 838 Returns: A callable accepting (first, second, msg=None) that will 839 raise a failure exception if first != second with a useful human 840 readable error message for those types. 841 """ 842 # 843 # NOTE(gregory.p.smith): I considered isinstance(first, type(second)) 844 # and vice versa. I opted for the conservative approach in case 845 # subclasses are not intended to be compared in detail to their super 846 # class instances using a type equality func. This means testing 847 # subtypes won't automagically use the detailed comparison. Callers 848 # should use their type specific assertSpamEqual method to compare 849 # subclasses if the detailed comparison is desired and appropriate. 
850 # See the discussion in http://bugs.python.org/issue2578. 851 # 852 if type(first) is type(second): 853 asserter = self._type_equality_funcs.get(type(first)) 854 if asserter is not None: 855 if isinstance(asserter, str): 856 asserter = getattr(self, asserter) 857 return asserter 858 859 return self._baseAssertEqual 860 861 def _baseAssertEqual(self, first, second, msg=None): 862 """The default assertEqual implementation, not type specific.""" 863 if not first == second: 864 standardMsg = '%s != %s' % _common_shorten_repr(first, second) 865 msg = self._formatMessage(msg, standardMsg) 866 raise self.failureException(msg) 867 868 def assertEqual(self, first, second, msg=None): 869 """Fail if the two objects are unequal as determined by the '==' 870 operator. 871 """ 872 assertion_func = self._getAssertEqualityFunc(first, second) 873 assertion_func(first, second, msg=msg) 874 875 def assertNotEqual(self, first, second, msg=None): 876 """Fail if the two objects are equal as determined by the '!=' 877 operator. 878 """ 879 if not first != second: 880 msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first), 881 safe_repr(second))) 882 raise self.failureException(msg) 883 884 def assertAlmostEqual(self, first, second, places=None, msg=None, 885 delta=None): 886 """Fail if the two objects are unequal as determined by their 887 difference rounded to the given number of decimal places 888 (default 7) and comparing to zero, or by comparing that the 889 difference between the two objects is more than the given 890 delta. 891 892 Note that decimal places (from zero) are usually not the same 893 as significant digits (measured from the most significant digit). 894 895 If the two objects compare equal then they will automatically 896 compare almost equal. 
897 """ 898 if first == second: 899 # shortcut 900 return 901 if delta is not None and places is not None: 902 raise TypeError("specify delta or places not both") 903 904 diff = abs(first - second) 905 if delta is not None: 906 if diff <= delta: 907 return 908 909 standardMsg = '%s != %s within %s delta (%s difference)' % ( 910 safe_repr(first), 911 safe_repr(second), 912 safe_repr(delta), 913 safe_repr(diff)) 914 else: 915 if places is None: 916 places = 7 917 918 if round(diff, places) == 0: 919 return 920 921 standardMsg = '%s != %s within %r places (%s difference)' % ( 922 safe_repr(first), 923 safe_repr(second), 924 places, 925 safe_repr(diff)) 926 msg = self._formatMessage(msg, standardMsg) 927 raise self.failureException(msg) 928 929 def assertNotAlmostEqual(self, first, second, places=None, msg=None, 930 delta=None): 931 """Fail if the two objects are equal as determined by their 932 difference rounded to the given number of decimal places 933 (default 7) and comparing to zero, or by comparing that the 934 difference between the two objects is less than the given delta. 935 936 Note that decimal places (from zero) are usually not the same 937 as significant digits (measured from the most significant digit). 938 939 Objects that are equal automatically fail. 
940 """ 941 if delta is not None and places is not None: 942 raise TypeError("specify delta or places not both") 943 diff = abs(first - second) 944 if delta is not None: 945 if not (first == second) and diff > delta: 946 return 947 standardMsg = '%s == %s within %s delta (%s difference)' % ( 948 safe_repr(first), 949 safe_repr(second), 950 safe_repr(delta), 951 safe_repr(diff)) 952 else: 953 if places is None: 954 places = 7 955 if not (first == second) and round(diff, places) != 0: 956 return 957 standardMsg = '%s == %s within %r places' % (safe_repr(first), 958 safe_repr(second), 959 places) 960 961 msg = self._formatMessage(msg, standardMsg) 962 raise self.failureException(msg) 963 964 def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None): 965 """An equality assertion for ordered sequences (like lists and tuples). 966 967 For the purposes of this function, a valid ordered sequence type is one 968 which can be indexed, has a length, and has an equality operator. 969 970 Args: 971 seq1: The first sequence to compare. 972 seq2: The second sequence to compare. 973 seq_type: The expected datatype of the sequences, or None if no 974 datatype should be enforced. 975 msg: Optional message to use on failure instead of a list of 976 differences. 977 """ 978 if seq_type is not None: 979 seq_type_name = seq_type.__name__ 980 if not isinstance(seq1, seq_type): 981 raise self.failureException('First sequence is not a %s: %s' 982 % (seq_type_name, safe_repr(seq1))) 983 if not isinstance(seq2, seq_type): 984 raise self.failureException('Second sequence is not a %s: %s' 985 % (seq_type_name, safe_repr(seq2))) 986 else: 987 seq_type_name = "sequence" 988 989 differing = None 990 try: 991 len1 = len(seq1) 992 except (TypeError, NotImplementedError): 993 differing = 'First %s has no length. Non-sequence?' 
% ( 994 seq_type_name) 995 996 if differing is None: 997 try: 998 len2 = len(seq2) 999 except (TypeError, NotImplementedError): 1000 differing = 'Second %s has no length. Non-sequence?' % ( 1001 seq_type_name) 1002 1003 if differing is None: 1004 if seq1 == seq2: 1005 return 1006 1007 differing = '%ss differ: %s != %s\n' % ( 1008 (seq_type_name.capitalize(),) + 1009 _common_shorten_repr(seq1, seq2)) 1010 1011 for i in range(min(len1, len2)): 1012 try: 1013 item1 = seq1[i] 1014 except (TypeError, IndexError, NotImplementedError): 1015 differing += ('\nUnable to index element %d of first %s\n' % 1016 (i, seq_type_name)) 1017 break 1018 1019 try: 1020 item2 = seq2[i] 1021 except (TypeError, IndexError, NotImplementedError): 1022 differing += ('\nUnable to index element %d of second %s\n' % 1023 (i, seq_type_name)) 1024 break 1025 1026 if item1 != item2: 1027 differing += ('\nFirst differing element %d:\n%s\n%s\n' % 1028 ((i,) + _common_shorten_repr(item1, item2))) 1029 break 1030 else: 1031 if (len1 == len2 and seq_type is None and 1032 type(seq1) != type(seq2)): 1033 # The sequences are the same, but have differing types. 
1034 return 1035 1036 if len1 > len2: 1037 differing += ('\nFirst %s contains %d additional ' 1038 'elements.\n' % (seq_type_name, len1 - len2)) 1039 try: 1040 differing += ('First extra element %d:\n%s\n' % 1041 (len2, safe_repr(seq1[len2]))) 1042 except (TypeError, IndexError, NotImplementedError): 1043 differing += ('Unable to index element %d ' 1044 'of first %s\n' % (len2, seq_type_name)) 1045 elif len1 < len2: 1046 differing += ('\nSecond %s contains %d additional ' 1047 'elements.\n' % (seq_type_name, len2 - len1)) 1048 try: 1049 differing += ('First extra element %d:\n%s\n' % 1050 (len1, safe_repr(seq2[len1]))) 1051 except (TypeError, IndexError, NotImplementedError): 1052 differing += ('Unable to index element %d ' 1053 'of second %s\n' % (len1, seq_type_name)) 1054 standardMsg = differing 1055 diffMsg = '\n' + '\n'.join( 1056 difflib.ndiff(pprint.pformat(seq1).splitlines(), 1057 pprint.pformat(seq2).splitlines())) 1058 1059 standardMsg = self._truncateMessage(standardMsg, diffMsg) 1060 msg = self._formatMessage(msg, standardMsg) 1061 self.fail(msg) 1062 1063 def _truncateMessage(self, message, diff): 1064 max_diff = self.maxDiff 1065 if max_diff is None or len(diff) <= max_diff: 1066 return message + diff 1067 return message + (DIFF_OMITTED % len(diff)) 1068 1069 def assertListEqual(self, list1, list2, msg=None): 1070 """A list-specific equality assertion. 1071 1072 Args: 1073 list1: The first list to compare. 1074 list2: The second list to compare. 1075 msg: Optional message to use on failure instead of a list of 1076 differences. 1077 1078 """ 1079 self.assertSequenceEqual(list1, list2, msg, seq_type=list) 1080 1081 def assertTupleEqual(self, tuple1, tuple2, msg=None): 1082 """A tuple-specific equality assertion. 1083 1084 Args: 1085 tuple1: The first tuple to compare. 1086 tuple2: The second tuple to compare. 1087 msg: Optional message to use on failure instead of a list of 1088 differences. 
1089 """ 1090 self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple) 1091 1092 def assertSetEqual(self, set1, set2, msg=None): 1093 """A set-specific equality assertion. 1094 1095 Args: 1096 set1: The first set to compare. 1097 set2: The second set to compare. 1098 msg: Optional message to use on failure instead of a list of 1099 differences. 1100 1101 assertSetEqual uses ducktyping to support different types of sets, and 1102 is optimized for sets specifically (parameters must support a 1103 difference method). 1104 """ 1105 try: 1106 difference1 = set1.difference(set2) 1107 except TypeError as e: 1108 self.fail('invalid type when attempting set difference: %s' % e) 1109 except AttributeError as e: 1110 self.fail('first argument does not support set difference: %s' % e) 1111 1112 try: 1113 difference2 = set2.difference(set1) 1114 except TypeError as e: 1115 self.fail('invalid type when attempting set difference: %s' % e) 1116 except AttributeError as e: 1117 self.fail('second argument does not support set difference: %s' % e) 1118 1119 if not (difference1 or difference2): 1120 return 1121 1122 lines = [] 1123 if difference1: 1124 lines.append('Items in the first set but not the second:') 1125 for item in difference1: 1126 lines.append(repr(item)) 1127 if difference2: 1128 lines.append('Items in the second set but not the first:') 1129 for item in difference2: 1130 lines.append(repr(item)) 1131 1132 standardMsg = '\n'.join(lines) 1133 self.fail(self._formatMessage(msg, standardMsg)) 1134 1135 def assertIn(self, member, container, msg=None): 1136 """Just like self.assertTrue(a in b), but with a nicer default message.""" 1137 if member not in container: 1138 standardMsg = '%s not found in %s' % (safe_repr(member), 1139 safe_repr(container)) 1140 self.fail(self._formatMessage(msg, standardMsg)) 1141 1142 def assertNotIn(self, member, container, msg=None): 1143 """Just like self.assertTrue(a not in b), but with a nicer default message.""" 1144 if member in 
container: 1145 standardMsg = '%s unexpectedly found in %s' % (safe_repr(member), 1146 safe_repr(container)) 1147 self.fail(self._formatMessage(msg, standardMsg)) 1148 1149 def assertIs(self, expr1, expr2, msg=None): 1150 """Just like self.assertTrue(a is b), but with a nicer default message.""" 1151 if expr1 is not expr2: 1152 standardMsg = '%s is not %s' % (safe_repr(expr1), 1153 safe_repr(expr2)) 1154 self.fail(self._formatMessage(msg, standardMsg)) 1155 1156 def assertIsNot(self, expr1, expr2, msg=None): 1157 """Just like self.assertTrue(a is not b), but with a nicer default message.""" 1158 if expr1 is expr2: 1159 standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),) 1160 self.fail(self._formatMessage(msg, standardMsg)) 1161 1162 def assertDictEqual(self, d1, d2, msg=None): 1163 self.assertIsInstance(d1, dict, 'First argument is not a dictionary') 1164 self.assertIsInstance(d2, dict, 'Second argument is not a dictionary') 1165 1166 if d1 != d2: 1167 standardMsg = '%s != %s' % _common_shorten_repr(d1, d2) 1168 diff = ('\n' + '\n'.join(difflib.ndiff( 1169 pprint.pformat(d1).splitlines(), 1170 pprint.pformat(d2).splitlines()))) 1171 standardMsg = self._truncateMessage(standardMsg, diff) 1172 self.fail(self._formatMessage(msg, standardMsg)) 1173 1174 def assertDictContainsSubset(self, subset, dictionary, msg=None): 1175 """Checks whether dictionary is a superset of subset.""" 1176 warnings.warn('assertDictContainsSubset is deprecated', 1177 DeprecationWarning) 1178 missing = [] 1179 mismatched = [] 1180 for key, value in subset.items(): 1181 if key not in dictionary: 1182 missing.append(key) 1183 elif value != dictionary[key]: 1184 mismatched.append('%s, expected: %s, actual: %s' % 1185 (safe_repr(key), safe_repr(value), 1186 safe_repr(dictionary[key]))) 1187 1188 if not (missing or mismatched): 1189 return 1190 1191 standardMsg = '' 1192 if missing: 1193 standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in 1194 missing) 1195 if mismatched: 1196 
if standardMsg: 1197 standardMsg += '; ' 1198 standardMsg += 'Mismatched values: %s' % ','.join(mismatched) 1199 1200 self.fail(self._formatMessage(msg, standardMsg)) 1201 1202 1203 def assertCountEqual(self, first, second, msg=None): 1204 """Asserts that two iterables have the same elements, the same number of 1205 times, without regard to order. 1206 1207 self.assertEqual(Counter(list(first)), 1208 Counter(list(second))) 1209 1210 Example: 1211 - [0, 1, 1] and [1, 0, 1] compare equal. 1212 - [0, 0, 1] and [0, 1] compare unequal. 1213 1214 """ 1215 first_seq, second_seq = list(first), list(second) 1216 try: 1217 first = collections.Counter(first_seq) 1218 second = collections.Counter(second_seq) 1219 except TypeError: 1220 # Handle case with unhashable elements 1221 differences = _count_diff_all_purpose(first_seq, second_seq) 1222 else: 1223 if first == second: 1224 return 1225 differences = _count_diff_hashable(first_seq, second_seq) 1226 1227 if differences: 1228 standardMsg = 'Element counts were not equal:\n' 1229 lines = ['First has %d, Second has %d: %r' % diff for diff in differences] 1230 diffMsg = '\n'.join(lines) 1231 standardMsg = self._truncateMessage(standardMsg, diffMsg) 1232 msg = self._formatMessage(msg, standardMsg) 1233 self.fail(msg) 1234 1235 def assertMultiLineEqual(self, first, second, msg=None): 1236 """Assert that two multi-line strings are equal.""" 1237 self.assertIsInstance(first, str, 'First argument is not a string') 1238 self.assertIsInstance(second, str, 'Second argument is not a string') 1239 1240 if first != second: 1241 # don't use difflib if the strings are too long 1242 if (len(first) > self._diffThreshold or 1243 len(second) > self._diffThreshold): 1244 self._baseAssertEqual(first, second, msg) 1245 firstlines = first.splitlines(keepends=True) 1246 secondlines = second.splitlines(keepends=True) 1247 if len(firstlines) == 1 and first.strip('\r\n') == first: 1248 firstlines = [first + '\n'] 1249 secondlines = [second + '\n'] 1250 
standardMsg = '%s != %s' % _common_shorten_repr(first, second) 1251 diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines)) 1252 standardMsg = self._truncateMessage(standardMsg, diff) 1253 self.fail(self._formatMessage(msg, standardMsg)) 1254 1255 def assertLess(self, a, b, msg=None): 1256 """Just like self.assertTrue(a < b), but with a nicer default message.""" 1257 if not a < b: 1258 standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b)) 1259 self.fail(self._formatMessage(msg, standardMsg)) 1260 1261 def assertLessEqual(self, a, b, msg=None): 1262 """Just like self.assertTrue(a <= b), but with a nicer default message.""" 1263 if not a <= b: 1264 standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b)) 1265 self.fail(self._formatMessage(msg, standardMsg)) 1266 1267 def assertGreater(self, a, b, msg=None): 1268 """Just like self.assertTrue(a > b), but with a nicer default message.""" 1269 if not a > b: 1270 standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b)) 1271 self.fail(self._formatMessage(msg, standardMsg)) 1272 1273 def assertGreaterEqual(self, a, b, msg=None): 1274 """Just like self.assertTrue(a >= b), but with a nicer default message.""" 1275 if not a >= b: 1276 standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b)) 1277 self.fail(self._formatMessage(msg, standardMsg)) 1278 1279 def assertIsNone(self, obj, msg=None): 1280 """Same as self.assertTrue(obj is None), with a nicer default message.""" 1281 if obj is not None: 1282 standardMsg = '%s is not None' % (safe_repr(obj),) 1283 self.fail(self._formatMessage(msg, standardMsg)) 1284 1285 def assertIsNotNone(self, obj, msg=None): 1286 """Included for symmetry with assertIsNone.""" 1287 if obj is None: 1288 standardMsg = 'unexpectedly None' 1289 self.fail(self._formatMessage(msg, standardMsg)) 1290 1291 def assertIsInstance(self, obj, cls, msg=None): 1292 """Same as self.assertTrue(isinstance(obj, cls)), with a nicer 1293 
default message.""" 1294 if not isinstance(obj, cls): 1295 standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls) 1296 self.fail(self._formatMessage(msg, standardMsg)) 1297 1298 def assertNotIsInstance(self, obj, cls, msg=None): 1299 """Included for symmetry with assertIsInstance.""" 1300 if isinstance(obj, cls): 1301 standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls) 1302 self.fail(self._formatMessage(msg, standardMsg)) 1303 1304 def assertRaisesRegex(self, expected_exception, expected_regex, 1305 *args, **kwargs): 1306 """Asserts that the message in a raised exception matches a regex. 1307 1308 Args: 1309 expected_exception: Exception class expected to be raised. 1310 expected_regex: Regex (re.Pattern object or string) expected 1311 to be found in error message. 1312 args: Function to be called and extra positional args. 1313 kwargs: Extra kwargs. 1314 msg: Optional message used in case of failure. Can only be used 1315 when assertRaisesRegex is used as a context manager. 1316 """ 1317 context = _AssertRaisesContext(expected_exception, self, expected_regex) 1318 return context.handle('assertRaisesRegex', args, kwargs) 1319 1320 def assertWarnsRegex(self, expected_warning, expected_regex, 1321 *args, **kwargs): 1322 """Asserts that the message in a triggered warning matches a regexp. 1323 Basic functioning is similar to assertWarns() with the addition 1324 that only warnings whose messages also match the regular expression 1325 are considered successful matches. 1326 1327 Args: 1328 expected_warning: Warning class expected to be triggered. 1329 expected_regex: Regex (re.Pattern object or string) expected 1330 to be found in error message. 1331 args: Function to be called and extra positional args. 1332 kwargs: Extra kwargs. 1333 msg: Optional message used in case of failure. Can only be used 1334 when assertWarnsRegex is used as a context manager. 
1335 """ 1336 context = _AssertWarnsContext(expected_warning, self, expected_regex) 1337 return context.handle('assertWarnsRegex', args, kwargs) 1338 1339 def assertRegex(self, text, expected_regex, msg=None): 1340 """Fail the test unless the text matches the regular expression.""" 1341 if isinstance(expected_regex, (str, bytes)): 1342 assert expected_regex, "expected_regex must not be empty." 1343 expected_regex = re.compile(expected_regex) 1344 if not expected_regex.search(text): 1345 standardMsg = "Regex didn't match: %r not found in %r" % ( 1346 expected_regex.pattern, text) 1347 # _formatMessage ensures the longMessage option is respected 1348 msg = self._formatMessage(msg, standardMsg) 1349 raise self.failureException(msg) 1350 1351 def assertNotRegex(self, text, unexpected_regex, msg=None): 1352 """Fail the test if the text matches the regular expression.""" 1353 if isinstance(unexpected_regex, (str, bytes)): 1354 unexpected_regex = re.compile(unexpected_regex) 1355 match = unexpected_regex.search(text) 1356 if match: 1357 standardMsg = 'Regex matched: %r matches %r in %r' % ( 1358 text[match.start() : match.end()], 1359 unexpected_regex.pattern, 1360 text) 1361 # _formatMessage ensures the longMessage option is respected 1362 msg = self._formatMessage(msg, standardMsg) 1363 raise self.failureException(msg) 1364 1365 1366 def _deprecate(original_func): 1367 def deprecated_func(*args, **kwargs): 1368 warnings.warn( 1369 'Please use {0} instead.'.format(original_func.__name__), 1370 DeprecationWarning, 2) 1371 return original_func(*args, **kwargs) 1372 return deprecated_func 1373 1374 # see #9424 1375 failUnlessEqual = assertEquals = _deprecate(assertEqual) 1376 failIfEqual = assertNotEquals = _deprecate(assertNotEqual) 1377 failUnlessAlmostEqual = assertAlmostEquals = _deprecate(assertAlmostEqual) 1378 failIfAlmostEqual = assertNotAlmostEquals = _deprecate(assertNotAlmostEqual) 1379 failUnless = assert_ = _deprecate(assertTrue) 1380 failUnlessRaises = 
_deprecate(assertRaises) 1381 failIf = _deprecate(assertFalse) 1382 assertRaisesRegexp = _deprecate(assertRaisesRegex) 1383 assertRegexpMatches = _deprecate(assertRegex) 1384 assertNotRegexpMatches = _deprecate(assertNotRegex) 1385 1386 1387 1388class FunctionTestCase(TestCase): 1389 """A test case that wraps a test function. 1390 1391 This is useful for slipping pre-existing test functions into the 1392 unittest framework. Optionally, set-up and tidy-up functions can be 1393 supplied. As with TestCase, the tidy-up ('tearDown') function will 1394 always be called if the set-up ('setUp') function ran successfully. 1395 """ 1396 1397 def __init__(self, testFunc, setUp=None, tearDown=None, description=None): 1398 super(FunctionTestCase, self).__init__() 1399 self._setUpFunc = setUp 1400 self._tearDownFunc = tearDown 1401 self._testFunc = testFunc 1402 self._description = description 1403 1404 def setUp(self): 1405 if self._setUpFunc is not None: 1406 self._setUpFunc() 1407 1408 def tearDown(self): 1409 if self._tearDownFunc is not None: 1410 self._tearDownFunc() 1411 1412 def runTest(self): 1413 self._testFunc() 1414 1415 def id(self): 1416 return self._testFunc.__name__ 1417 1418 def __eq__(self, other): 1419 if not isinstance(other, self.__class__): 1420 return NotImplemented 1421 1422 return self._setUpFunc == other._setUpFunc and \ 1423 self._tearDownFunc == other._tearDownFunc and \ 1424 self._testFunc == other._testFunc and \ 1425 self._description == other._description 1426 1427 def __hash__(self): 1428 return hash((type(self), self._setUpFunc, self._tearDownFunc, 1429 self._testFunc, self._description)) 1430 1431 def __str__(self): 1432 return "%s (%s)" % (strclass(self.__class__), 1433 self._testFunc.__name__) 1434 1435 def __repr__(self): 1436 return "<%s tec=%s>" % (strclass(self.__class__), 1437 self._testFunc) 1438 1439 def shortDescription(self): 1440 if self._description is not None: 1441 return self._description 1442 doc = self._testFunc.__doc__ 1443 
return doc and doc.split("\n")[0].strip() or None 1444 1445 1446class _SubTest(TestCase): 1447 1448 def __init__(self, test_case, message, params): 1449 super().__init__() 1450 self._message = message 1451 self.test_case = test_case 1452 self.params = params 1453 self.failureException = test_case.failureException 1454 1455 def runTest(self): 1456 raise NotImplementedError("subtests cannot be run directly") 1457 1458 def _subDescription(self): 1459 parts = [] 1460 if self._message is not _subtest_msg_sentinel: 1461 parts.append("[{}]".format(self._message)) 1462 if self.params: 1463 params_desc = ', '.join( 1464 "{}={!r}".format(k, v) 1465 for (k, v) in self.params.items()) 1466 parts.append("({})".format(params_desc)) 1467 return " ".join(parts) or '(<subtest>)' 1468 1469 def id(self): 1470 return "{} {}".format(self.test_case.id(), self._subDescription()) 1471 1472 def shortDescription(self): 1473 """Returns a one-line description of the subtest, or None if no 1474 description has been provided. 1475 """ 1476 return self.test_case.shortDescription() 1477 1478 def __str__(self): 1479 return "{} {}".format(self.test_case, self._subDescription()) 1480