1"""Test case implementation"""
2
3import sys
4import functools
5import difflib
6import pprint
7import re
8import warnings
9import collections
10import contextlib
11import traceback
12import types
13
14from . import result
15from .util import (strclass, safe_repr, _count_diff_all_purpose,
16                   _count_diff_hashable, _common_shorten_repr)
17
# Module-level marker flag. NOTE(review): presumably checked by the result
# machinery to hide this module's frames from reported tracebacks -- confirm
# against result.py.
__unittest = True

# Unique default for TestCase.subTest()'s ``msg`` parameter, so that "no
# message given" is distinguishable from any value a caller could pass.
_subtest_msg_sentinel = object()

# Notice used when a diff is not shown in full; the %s placeholder receives
# the diff's length in characters.
DIFF_OMITTED = ("\nDiff is %s characters long. "
                "Set self.maxDiff to None to see it.")
24
class SkipTest(Exception):
    """Signal that the running test should be skipped.

    Raising this inside a test records it as skipped rather than failed.
    TestCase.skipTest() or one of the skip decorators (skip, skipIf,
    skipUnless) is usually the more convenient way to get the same effect.
    """
32
class _ShouldStop(Exception):
    """Internal signal that the current test should stop running.

    Raised from TestCase.subTest() (on failfast, or once an expected
    failure has been recorded) and silently swallowed by
    _Outcome.testPartExecutor().
    """
37
class _UnexpectedSuccess(Exception):
    """Internal marker for a test that was expected to fail but passed.

    It is raised only to obtain a real exc_info tuple for result objects
    that lack addUnexpectedSuccess() and must be sent addFailure() instead.
    """
42
43
class _Outcome(object):
    """Tracks the combined outcome of the parts of a single test run.

    Each part (setUp, the test method, tearDown, each cleanup, each subtest)
    runs inside testPartExecutor(), which reports skips and errors to
    ``result`` and folds the part's success into ``self.success``.
    """

    def __init__(self, result=None):
        # When True, an exception raised in a part is stored as the
        # expected failure instead of being reported as an error/failure.
        self.expecting_failure = False
        # TestResult-like object that receives skips, errors and subtests.
        self.result = result
        # Subtests are only tracked when the result can receive them.
        self.result_supports_subtests = hasattr(result, "addSubTest")
        # Cumulative success flag over every part run so far.
        self.success = True
        # exc_info tuple of the expected failure, once one has occurred.
        self.expectedFailure = None

    @contextlib.contextmanager
    def testPartExecutor(self, test_case, subTest=False):
        """Run the body of the ``with`` block as one part of a test.

        Args:
            test_case: the object outcomes are attributed to -- the TestCase,
                or (when ``subTest`` is true) the subtest object, whose
                ``test_case`` attribute is the owning TestCase.
            subTest: if true, report the outcome via result.addSubTest().

        KeyboardInterrupt propagates; SkipTest is reported as a skip;
        _ShouldStop is swallowed; any other exception is captured either as
        the expected failure or reported as an error/failure.
        """
        old_success = self.success
        # Measure this part's success on its own; merged back in `finally`.
        self.success = True
        try:
            yield
        except KeyboardInterrupt:
            raise
        except SkipTest as e:
            self.success = False
            _addSkip(self.result, test_case, str(e))
        except _ShouldStop:
            pass
        except:
            # Deliberately bare: every remaining exception, including other
            # BaseException subclasses, is captured and reported here.
            exc_info = sys.exc_info()
            if self.expecting_failure:
                self.expectedFailure = exc_info
            else:
                self.success = False
                if subTest:
                    self.result.addSubTest(test_case.test_case, test_case, exc_info)
                else:
                    _addError(self.result, test_case, exc_info)
            # explicitly break a reference cycle:
            # exc_info -> frame -> exc_info
            exc_info = None
        else:
            # A subtest whose body finished without raising is reported as a
            # success, unless a nested part already marked it as failed.
            if subTest and self.success:
                self.result.addSubTest(test_case.test_case, test_case, None)
        finally:
            # Once any part has failed, the overall outcome stays failed.
            self.success = self.success and old_success
83
84
def _addSkip(result, test_case, reason):
    """Report ``test_case`` as skipped with ``reason`` on ``result``.

    A result object without an ``addSkip`` method triggers a
    RuntimeWarning, and the test is recorded through addSuccess() instead.
    """
    add_skip = getattr(result, 'addSkip', None)
    if add_skip is None:
        warnings.warn("TestResult has no addSkip method, skips not reported",
                      RuntimeWarning, 2)
        result.addSuccess(test_case)
        return
    add_skip(test_case, reason)
93
def _addError(result, test, exc_info):
    """Forward ``exc_info`` to ``result`` as either a failure or an error.

    Exceptions whose type is ``test.failureException`` (or a subclass) go to
    addFailure(); all others go to addError().  Nothing is reported when
    ``result`` or ``exc_info`` is None.
    """
    if result is None or exc_info is None:
        return
    exc_type = exc_info[0]
    report = (result.addFailure if issubclass(exc_type, test.failureException)
              else result.addError)
    report(test, exc_info)
100
def _id(obj):
    """No-op decorator: hand back ``obj`` untouched."""
    return obj
103
104
def _enter_context(cm, addcleanup):
    """Enter ``cm`` and register its ``__exit__`` through ``addcleanup``.

    As with the ``with`` statement, the special methods are looked up on the
    type rather than on the instance.  The registered cleanup is
    ``__exit__`` called with ``(cm, None, None, None)``.

    Returns whatever ``__enter__`` returned.
    Raises TypeError if the type lacks ``__enter__`` or ``__exit__``.
    """
    cm_type = type(cm)
    try:
        enter_method = cm_type.__enter__
        exit_method = cm_type.__exit__
    except AttributeError:
        raise TypeError(f"'{cm_type.__module__}.{cm_type.__qualname__}' object does "
                        f"not support the context manager protocol") from None
    entered = enter_method(cm)
    addcleanup(exit_method, cm, None, None, None)
    return entered
118
119
# Registered module cleanups as (function, args, kwargs) tuples;
# doModuleCleanups() pops and runs them, most recent first.
_module_cleanups = []


def addModuleCleanup(function, /, *args, **kwargs):
    """Register ``function(*args, **kwargs)`` to run as a module cleanup.

    Same as addCleanup, except the cleanup items are called even if
    setUpModule fails (unlike tearDownModule).
    """
    entry = (function, args, kwargs)
    _module_cleanups.append(entry)
125
def enterModuleContext(cm):
    """Enter ``cm`` and schedule its exit as a module cleanup.

    Same as enterContext, but module-wide: returns the result of the
    context manager's ``__enter__`` and registers its ``__exit__`` through
    addModuleCleanup().
    """
    register = addModuleCleanup
    return _enter_context(cm, register)
129
130
def doModuleCleanups():
    """Run every registered module cleanup, most recently added first.

    Normally called for you after tearDownModule.  Every cleanup runs even
    if an earlier one raised an Exception; afterwards the first such
    exception (if any) is re-raised and the rest are discarded.
    """
    errors = []
    while _module_cleanups:
        func, func_args, func_kwargs = _module_cleanups.pop()
        try:
            func(*func_args, **func_kwargs)
        except Exception as err:
            errors.append(err)
    if errors:
        # Only the first exception propagates. If a multi-exception
        # handler gets written we should use that here instead.
        raise errors[0]
145
146
def skip(reason):
    """Unconditionally skip a test.

    Usable as ``@skip("why")`` or as a bare ``@skip`` on a function (the
    reason is then the empty string).  Classes are marked in place;
    functions are replaced by a wrapper that raises SkipTest when called.
    """
    def decorator(test_item):
        if not isinstance(test_item, type):
            @functools.wraps(test_item)
            def skip_wrapper(*args, **kwargs):
                raise SkipTest(reason)

            test_item = skip_wrapper
        test_item.__unittest_skip__ = True
        test_item.__unittest_skip_why__ = reason
        return test_item

    if not isinstance(reason, types.FunctionType):
        return decorator
    # Bare ``@skip``: the "reason" is really the function being decorated.
    # Rebinding ``reason`` first means the wrapper's closure sees ''.
    func, reason = reason, ''
    return decorator(func)
166
def skipIf(condition, reason):
    """Skip the decorated test if ``condition`` is true.

    Returns skip(reason) when ``condition`` is truthy; otherwise a no-op
    decorator that leaves the test unchanged.
    """
    return skip(reason) if condition else _id
174
def skipUnless(condition, reason):
    """Skip the decorated test unless ``condition`` is true.

    Returns a no-op decorator when ``condition`` is truthy; otherwise
    skip(reason).
    """
    return _id if condition else skip(reason)
182
def expectedFailure(test_item):
    """Mark ``test_item`` as a test that is expected to fail.

    The ``__unittest_expecting_failure__`` flag is set on the item in place
    and the same object is returned.  TestCase.run() then reports a failure
    as an expected failure and a pass as an unexpected success.
    """
    setattr(test_item, "__unittest_expecting_failure__", True)
    return test_item
186
def _is_subtype(expected, basetype):
    """Return True if ``expected`` is a subclass of ``basetype``.

    ``expected`` may also be a (possibly nested) tuple of classes.  Every
    element must then qualify, and an empty tuple qualifies trivially.
    Non-class values yield False.
    """
    if not isinstance(expected, tuple):
        return isinstance(expected, type) and issubclass(expected, basetype)
    return all(_is_subtype(item, basetype) for item in expected)
191
class _BaseTestCaseContext:
    """Shared base for the assertion context managers.

    Subclasses must set ``self.msg`` (the caller's optional message) before
    _raiseFailure() is used.
    """

    def __init__(self, test_case):
        self.test_case = test_case

    def _raiseFailure(self, standardMsg):
        """Raise the test case's failureException for ``standardMsg``.

        The text is combined with ``self.msg`` by the test case's
        _formatMessage().
        """
        failure_text = self.test_case._formatMessage(self.msg, standardMsg)
        raise self.test_case.failureException(failure_text)
200
class _AssertRaisesBaseContext(_BaseTestCaseContext):
    """Common machinery for the assertRaises*/assertWarns* context managers.

    Subclasses define ``_base_type`` and ``_base_type_str``, which name the
    type that ``expected`` must be a subtype of.  They also define the
    checks done in ``__enter__``/``__exit__``.
    """

    def __init__(self, expected, test_case, expected_regex=None):
        _BaseTestCaseContext.__init__(self, test_case)
        self.expected = expected
        self.test_case = test_case
        # Compile the optional pattern once, up front.
        self.expected_regex = (None if expected_regex is None
                               else re.compile(expected_regex))
        self.obj_name = None
        self.msg = None

    def handle(self, name, args, kwargs):
        """Run the check either as a context manager or around a callable.

        With no positional ``args``, the only accepted keyword is ``msg``.
        It is stored, and ``self`` is returned for use in a ``with``
        statement.  Otherwise ``args[0]`` is called inside ``with self`` with
        the remaining positional arguments and ``kwargs``; None is returned.

        Raises TypeError when ``expected`` is not of the required kind, or
        when an unknown keyword is passed in context-manager mode.
        """
        try:
            if not _is_subtype(self.expected, self._base_type):
                raise TypeError('%s() arg 1 must be %s' %
                                (name, self._base_type_str))
            if args:
                callable_obj, *call_args = args
                try:
                    self.obj_name = callable_obj.__name__
                except AttributeError:
                    self.obj_name = str(callable_obj)
                with self:
                    callable_obj(*call_args, **kwargs)
                return None
            self.msg = kwargs.pop('msg', None)
            if kwargs:
                first_bad = next(iter(kwargs))
                raise TypeError('%r is an invalid keyword argument for '
                                'this function' % (first_bad,))
            return self
        finally:
            # bpo-23890: manually break a reference cycle
            self = None
241
242
class _AssertRaisesContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertRaises* methods."""

    _base_type = BaseException
    _base_type_str = 'an exception type or tuple of exception types'

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Check the exception, if any, that ended the ``with`` block.

        Fails when nothing was raised, or when ``expected_regex`` does not
        match the exception's text.  Returns True to suppress a matching
        exception and False to let any other exception propagate.
        """
        if exc_type is None:
            try:
                exc_name = self.expected.__name__
            except AttributeError:
                exc_name = str(self.expected)
            if self.obj_name:
                failure = "{} not raised by {}".format(exc_name, self.obj_name)
            else:
                failure = "{} not raised".format(exc_name)
            self._raiseFailure(failure)
        else:
            # Clear the locals held by the traceback's frames.
            traceback.clear_frames(tb)
        if not issubclass(exc_type, self.expected):
            # let unexpected exceptions pass through
            return False
        # store exception, without traceback, for later retrieval
        self.exception = exc_value.with_traceback(None)
        pattern = self.expected_regex
        if pattern is not None and not pattern.search(str(exc_value)):
            self._raiseFailure('"{}" does not match "{}"'.format(
                     pattern.pattern, str(exc_value)))
        return True

    __class_getitem__ = classmethod(types.GenericAlias)
280
281
class _AssertWarnsContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertWarns* methods."""

    _base_type = Warning
    _base_type_str = 'a warning type or tuple of warning types'

    def __enter__(self):
        """Start recording warnings, with ``expected`` set to "always"."""
        # The __warningregistry__'s need to be in a pristine state for tests
        # to work properly.
        for module in list(sys.modules.values()):
            if getattr(module, '__warningregistry__', None):
                module.__warningregistry__ = {}
        self.warnings_manager = warnings.catch_warnings(record=True)
        self.warnings = self.warnings_manager.__enter__()
        warnings.simplefilter("always", self.expected)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Stop recording and check that a matching warning was issued.

        On success, the first matching warning is stored as ``warning``,
        together with its ``filename`` and ``lineno``.  An exception raised
        inside the block propagates unchanged, and no check is made.
        """
        self.warnings_manager.__exit__(exc_type, exc_value, tb)
        if exc_type is not None:
            # let unexpected exceptions pass through
            return
        try:
            exc_name = self.expected.__name__
        except AttributeError:
            exc_name = str(self.expected)
        regex = self.expected_regex
        first_matching = None
        for record in self.warnings:
            message = record.message
            if not isinstance(message, self.expected):
                continue
            if first_matching is None:
                first_matching = message
            if regex is None or regex.search(str(message)):
                # store warning for later retrieval
                self.warning = message
                self.filename = record.filename
                self.lineno = record.lineno
                return
        # Now we simply try to choose a helpful failure message
        if first_matching is not None:
            self._raiseFailure('"{}" does not match "{}"'.format(
                     regex.pattern, str(first_matching)))
        if self.obj_name:
            self._raiseFailure("{} not triggered by {}".format(exc_name,
                                                               self.obj_name))
        else:
            self._raiseFailure("{} not triggered".format(exc_name))
332
333
class _OrderedChainMap(collections.ChainMap):
    """ChainMap whose iteration order follows the order of its maps.

    Keys are yielded from the front map first, each map in its own order.
    A key already produced by an earlier map is skipped.
    """

    def __iter__(self):
        yielded = set()
        for mapping in self.maps:
            for key in mapping:
                if key in yielded:
                    continue
                yielded.add(key)
                yield key
342
343
344class TestCase(object):
345    """A class whose instances are single test cases.
346
347    By default, the test code itself should be placed in a method named
348    'runTest'.
349
350    If the fixture may be used for many test cases, create as
351    many test methods as are needed. When instantiating such a TestCase
352    subclass, specify in the constructor arguments the name of the test method
353    that the instance is to execute.
354
355    Test authors should subclass TestCase for their own tests. Construction
356    and deconstruction of the test's environment ('fixture') can be
357    implemented by overriding the 'setUp' and 'tearDown' methods respectively.
358
359    If it is necessary to override the __init__ method, the base class
360    __init__ method must always be called. It is important that subclasses
361    should not change the signature of their __init__ method, since instances
362    of the classes are instantiated automatically by parts of the framework
363    in order to be run.
364
365    When subclassing TestCase, you can set these attributes:
366    * failureException: determines which exception will be raised when
367        the instance's assertion methods fail; test methods raising this
368        exception will be deemed to have 'failed' rather than 'errored'.
369    * longMessage: determines whether long messages (including repr of
370        objects used in assert methods) will be printed on failure in *addition*
371        to any explicit message passed.
372    * maxDiff: sets the maximum length of a diff in failure messages
373        by assert methods using difflib. It is looked up as an instance
374        attribute so can be configured by individual tests if required.
375    """
376
377    failureException = AssertionError
378
379    longMessage = True
380
381    maxDiff = 80*8
382
383    # If a string is longer than _diffThreshold, use normal comparison instead
384    # of difflib.  See #11763.
385    _diffThreshold = 2**16
386
387    def __init_subclass__(cls, *args, **kwargs):
388        # Attribute used by TestSuite for classSetUp
389        cls._classSetupFailed = False
390        cls._class_cleanups = []
391        super().__init_subclass__(*args, **kwargs)
392
393    def __init__(self, methodName='runTest'):
394        """Create an instance of the class that will use the named test
395           method when executed. Raises a ValueError if the instance does
396           not have a method with the specified name.
397        """
398        self._testMethodName = methodName
399        self._outcome = None
400        self._testMethodDoc = 'No test'
401        try:
402            testMethod = getattr(self, methodName)
403        except AttributeError:
404            if methodName != 'runTest':
405                # we allow instantiation with no explicit method name
406                # but not an *incorrect* or missing method name
407                raise ValueError("no such test method in %s: %s" %
408                      (self.__class__, methodName))
409        else:
410            self._testMethodDoc = testMethod.__doc__
411        self._cleanups = []
412        self._subtest = None
413
414        # Map types to custom assertEqual functions that will compare
415        # instances of said type in more detail to generate a more useful
416        # error message.
417        self._type_equality_funcs = {}
418        self.addTypeEqualityFunc(dict, 'assertDictEqual')
419        self.addTypeEqualityFunc(list, 'assertListEqual')
420        self.addTypeEqualityFunc(tuple, 'assertTupleEqual')
421        self.addTypeEqualityFunc(set, 'assertSetEqual')
422        self.addTypeEqualityFunc(frozenset, 'assertSetEqual')
423        self.addTypeEqualityFunc(str, 'assertMultiLineEqual')
424
425    def addTypeEqualityFunc(self, typeobj, function):
426        """Add a type specific assertEqual style function to compare a type.
427
428        This method is for use by TestCase subclasses that need to register
429        their own type equality functions to provide nicer error messages.
430
431        Args:
432            typeobj: The data type to call this function on when both values
433                    are of the same type in assertEqual().
434            function: The callable taking two arguments and an optional
435                    msg= argument that raises self.failureException with a
436                    useful error message when the two arguments are not equal.
437        """
438        self._type_equality_funcs[typeobj] = function
439
440    def addCleanup(self, function, /, *args, **kwargs):
441        """Add a function, with arguments, to be called when the test is
442        completed. Functions added are called on a LIFO basis and are
443        called after tearDown on test failure or success.
444
445        Cleanup items are called even if setUp fails (unlike tearDown)."""
446        self._cleanups.append((function, args, kwargs))
447
448    def enterContext(self, cm):
449        """Enters the supplied context manager.
450
451        If successful, also adds its __exit__ method as a cleanup
452        function and returns the result of the __enter__ method.
453        """
454        return _enter_context(cm, self.addCleanup)
455
456    @classmethod
457    def addClassCleanup(cls, function, /, *args, **kwargs):
458        """Same as addCleanup, except the cleanup items are called even if
459        setUpClass fails (unlike tearDownClass)."""
460        cls._class_cleanups.append((function, args, kwargs))
461
462    @classmethod
463    def enterClassContext(cls, cm):
464        """Same as enterContext, but class-wide."""
465        return _enter_context(cm, cls.addClassCleanup)
466
467    def setUp(self):
468        "Hook method for setting up the test fixture before exercising it."
469        pass
470
471    def tearDown(self):
472        "Hook method for deconstructing the test fixture after testing it."
473        pass
474
475    @classmethod
476    def setUpClass(cls):
477        "Hook method for setting up class fixture before running tests in the class."
478
479    @classmethod
480    def tearDownClass(cls):
481        "Hook method for deconstructing the class fixture after running all tests in the class."
482
483    def countTestCases(self):
484        return 1
485
486    def defaultTestResult(self):
487        return result.TestResult()
488
489    def shortDescription(self):
490        """Returns a one-line description of the test, or None if no
491        description has been provided.
492
493        The default implementation of this method returns the first line of
494        the specified test method's docstring.
495        """
496        doc = self._testMethodDoc
497        return doc.strip().split("\n")[0].strip() if doc else None
498
499
500    def id(self):
501        return "%s.%s" % (strclass(self.__class__), self._testMethodName)
502
503    def __eq__(self, other):
504        if type(self) is not type(other):
505            return NotImplemented
506
507        return self._testMethodName == other._testMethodName
508
509    def __hash__(self):
510        return hash((type(self), self._testMethodName))
511
512    def __str__(self):
513        return "%s (%s.%s)" % (self._testMethodName, strclass(self.__class__), self._testMethodName)
514
515    def __repr__(self):
516        return "<%s testMethod=%s>" % \
517               (strclass(self.__class__), self._testMethodName)
518
    @contextlib.contextmanager
    def subTest(self, msg=_subtest_msg_sentinel, **params):
        """Return a context manager that will return the enclosed block
        of code in a subtest identified by the optional message and
        keyword parameters.  A failure in the subtest marks the test
        case as failed but resumes execution at the end of the enclosed
        block, allowing further test code to be executed.

        When there is no active outcome, or the result object cannot record
        subtests, the block simply runs inline with no subtest bookkeeping.
        Nested subtests layer their ``params`` over those of the enclosing
        subtest.
        """
        # No active run, or the result cannot take subtests: just run the block.
        if self._outcome is None or not self._outcome.result_supports_subtests:
            yield
            return
        parent = self._subtest
        # parent.params is presumably the ChainMap built here for the parent
        # subtest; new_child() puts this subtest's params in front of it.
        if parent is None:
            params_map = _OrderedChainMap(params)
        else:
            params_map = parent.params.new_child(params)
        self._subtest = _SubTest(self, msg, params_map)
        try:
            with self._outcome.testPartExecutor(self._subtest, subTest=True):
                yield
            if not self._outcome.success:
                result = self._outcome.result
                # failfast: a failed subtest ends the enclosing test part.
                if result is not None and result.failfast:
                    raise _ShouldStop
            elif self._outcome.expectedFailure:
                # If the test is expecting a failure, we really want to
                # stop now and register the expected failure.
                raise _ShouldStop
        finally:
            # Restore the enclosing subtest (or None) on the way out.
            self._subtest = parent
549
550    def _addExpectedFailure(self, result, exc_info):
551        try:
552            addExpectedFailure = result.addExpectedFailure
553        except AttributeError:
554            warnings.warn("TestResult has no addExpectedFailure method, reporting as passes",
555                          RuntimeWarning)
556            result.addSuccess(self)
557        else:
558            addExpectedFailure(self, exc_info)
559
560    def _addUnexpectedSuccess(self, result):
561        try:
562            addUnexpectedSuccess = result.addUnexpectedSuccess
563        except AttributeError:
564            warnings.warn("TestResult has no addUnexpectedSuccess method, reporting as failure",
565                          RuntimeWarning)
566            # We need to pass an actual exception and traceback to addFailure,
567            # otherwise the legacy result can choke.
568            try:
569                raise _UnexpectedSuccess from None
570            except _UnexpectedSuccess:
571                result.addFailure(self, sys.exc_info())
572        else:
573            addUnexpectedSuccess(self)
574
575    def _callSetUp(self):
576        self.setUp()
577
578    def _callTestMethod(self, method):
579        if method() is not None:
580            warnings.warn(f'It is deprecated to return a value that is not None from a '
581                          f'test case ({method})', DeprecationWarning, stacklevel=3)
582
583    def _callTearDown(self):
584        self.tearDown()
585
586    def _callCleanup(self, function, /, *args, **kwargs):
587        function(*args, **kwargs)
588
    def run(self, result=None):
        """Run the test and collect its outcome into ``result``.

        If ``result`` is None, a new defaultTestResult() is used.  Its
        startTestRun()/stopTestRun() hooks, if present, are called around
        this single test.  A skipped class or method is reported through
        _addSkip() without running setUp.

        Otherwise setUp runs first.  The test method and tearDown run only
        if setUp succeeded, while cleanups always run.  Each of these is a
        separate test part under an _Outcome.  If no part failed, a success,
        expected failure or unexpected success is reported at the end.

        Returns the result object.
        """
        if result is None:
            result = self.defaultTestResult()
            startTestRun = getattr(result, 'startTestRun', None)
            stopTestRun = getattr(result, 'stopTestRun', None)
            if startTestRun is not None:
                startTestRun()
        else:
            # A caller-supplied result is not given stopTestRun() here.
            stopTestRun = None

        result.startTest(self)
        try:
            testMethod = getattr(self, self._testMethodName)
            if (getattr(self.__class__, "__unittest_skip__", False) or
                getattr(testMethod, "__unittest_skip__", False)):
                # If the class or method was skipped.
                skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
                            or getattr(testMethod, '__unittest_skip_why__', ''))
                _addSkip(result, self, skip_why)
                return result

            # The @expectedFailure flag may be set on the class (found via
            # the instance) or on the test method itself.
            expecting_failure = (
                getattr(self, "__unittest_expecting_failure__", False) or
                getattr(testMethod, "__unittest_expecting_failure__", False)
            )
            outcome = _Outcome(result)
            try:
                # Published so subTest() and doCleanups() use this outcome.
                self._outcome = outcome

                with outcome.testPartExecutor(self):
                    self._callSetUp()
                if outcome.success:
                    # Only exceptions from the test method itself count as
                    # the expected failure; setUp/tearDown errors do not.
                    outcome.expecting_failure = expecting_failure
                    with outcome.testPartExecutor(self):
                        self._callTestMethod(testMethod)
                    outcome.expecting_failure = False
                    with outcome.testPartExecutor(self):
                        self._callTearDown()
                self.doCleanups()

                # Errors, failures and skips were already reported by the
                # test parts; only the remaining outcomes are reported here.
                if outcome.success:
                    if expecting_failure:
                        if outcome.expectedFailure:
                            self._addExpectedFailure(result, outcome.expectedFailure)
                        else:
                            self._addUnexpectedSuccess(result)
                    else:
                        result.addSuccess(self)
                return result
            finally:
                # explicitly break reference cycle:
                # outcome.expectedFailure -> frame -> outcome -> outcome.expectedFailure
                outcome.expectedFailure = None
                outcome = None

                # clear the outcome, no more needed
                self._outcome = None

        finally:
            result.stopTest(self)
            if stopTestRun is not None:
                stopTestRun()
651
652    def doCleanups(self):
653        """Execute all cleanup functions. Normally called for you after
654        tearDown."""
655        outcome = self._outcome or _Outcome()
656        while self._cleanups:
657            function, args, kwargs = self._cleanups.pop()
658            with outcome.testPartExecutor(self):
659                self._callCleanup(function, *args, **kwargs)
660
661        # return this for backwards compatibility
662        # even though we no longer use it internally
663        return outcome.success
664
665    @classmethod
666    def doClassCleanups(cls):
667        """Execute all class cleanup functions. Normally called for you after
668        tearDownClass."""
669        cls.tearDown_exceptions = []
670        while cls._class_cleanups:
671            function, args, kwargs = cls._class_cleanups.pop()
672            try:
673                function(*args, **kwargs)
674            except Exception:
675                cls.tearDown_exceptions.append(sys.exc_info())
676
677    def __call__(self, *args, **kwds):
678        return self.run(*args, **kwds)
679
680    def debug(self):
681        """Run the test without collecting errors in a TestResult"""
682        testMethod = getattr(self, self._testMethodName)
683        if (getattr(self.__class__, "__unittest_skip__", False) or
684            getattr(testMethod, "__unittest_skip__", False)):
685            # If the class or method was skipped.
686            skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
687                        or getattr(testMethod, '__unittest_skip_why__', ''))
688            raise SkipTest(skip_why)
689
690        self._callSetUp()
691        self._callTestMethod(testMethod)
692        self._callTearDown()
693        while self._cleanups:
694            function, args, kwargs = self._cleanups.pop()
695            self._callCleanup(function, *args, **kwargs)
696
697    def skipTest(self, reason):
698        """Skip this test."""
699        raise SkipTest(reason)
700
701    def fail(self, msg=None):
702        """Fail immediately, with the given message."""
703        raise self.failureException(msg)
704
705    def assertFalse(self, expr, msg=None):
706        """Check that the expression is false."""
707        if expr:
708            msg = self._formatMessage(msg, "%s is not false" % safe_repr(expr))
709            raise self.failureException(msg)
710
711    def assertTrue(self, expr, msg=None):
712        """Check that the expression is true."""
713        if not expr:
714            msg = self._formatMessage(msg, "%s is not true" % safe_repr(expr))
715            raise self.failureException(msg)
716
717    def _formatMessage(self, msg, standardMsg):
718        """Honour the longMessage attribute when generating failure messages.
719        If longMessage is False this means:
720        * Use only an explicit message if it is provided
721        * Otherwise use the standard message for the assert
722
723        If longMessage is True:
724        * Use the standard message
725        * If an explicit message is provided, plus ' : ' and the explicit message
726        """
727        if not self.longMessage:
728            return msg or standardMsg
729        if msg is None:
730            return standardMsg
731        try:
732            # don't switch to '{}' formatting in Python 2.X
733            # it changes the way unicode input is handled
734            return '%s : %s' % (standardMsg, msg)
735        except UnicodeDecodeError:
736            return  '%s : %s' % (safe_repr(standardMsg), safe_repr(msg))
737
738    def assertRaises(self, expected_exception, *args, **kwargs):
739        """Fail unless an exception of class expected_exception is raised
740           by the callable when invoked with specified positional and
741           keyword arguments. If a different type of exception is
742           raised, it will not be caught, and the test case will be
743           deemed to have suffered an error, exactly as for an
744           unexpected exception.
745
746           If called with the callable and arguments omitted, will return a
747           context object used like this::
748
749                with self.assertRaises(SomeException):
750                    do_something()
751
752           An optional keyword argument 'msg' can be provided when assertRaises
753           is used as a context object.
754
755           The context manager keeps a reference to the exception as
756           the 'exception' attribute. This allows you to inspect the
757           exception after the assertion::
758
759               with self.assertRaises(SomeException) as cm:
760                   do_something()
761               the_exception = cm.exception
762               self.assertEqual(the_exception.error_code, 3)
763        """
764        context = _AssertRaisesContext(expected_exception, self)
765        try:
766            return context.handle('assertRaises', args, kwargs)
767        finally:
768            # bpo-23890: manually break a reference cycle
769            context = None
770
771    def assertWarns(self, expected_warning, *args, **kwargs):
772        """Fail unless a warning of class warnClass is triggered
773           by the callable when invoked with specified positional and
774           keyword arguments.  If a different type of warning is
775           triggered, it will not be handled: depending on the other
776           warning filtering rules in effect, it might be silenced, printed
777           out, or raised as an exception.
778
779           If called with the callable and arguments omitted, will return a
780           context object used like this::
781
782                with self.assertWarns(SomeWarning):
783                    do_something()
784
785           An optional keyword argument 'msg' can be provided when assertWarns
786           is used as a context object.
787
788           The context manager keeps a reference to the first matching
789           warning as the 'warning' attribute; similarly, the 'filename'
790           and 'lineno' attributes give you information about the line
791           of Python code from which the warning was triggered.
792           This allows you to inspect the warning after the assertion::
793
794               with self.assertWarns(SomeWarning) as cm:
795                   do_something()
796               the_warning = cm.warning
797               self.assertEqual(the_warning.some_attribute, 147)
798        """
799        context = _AssertWarnsContext(expected_warning, self)
800        return context.handle('assertWarns', args, kwargs)
801
802    def assertLogs(self, logger=None, level=None):
803        """Fail unless a log message of level *level* or higher is emitted
804        on *logger_name* or its children.  If omitted, *level* defaults to
805        INFO and *logger* defaults to the root logger.
806
807        This method must be used as a context manager, and will yield
808        a recording object with two attributes: `output` and `records`.
809        At the end of the context manager, the `output` attribute will
810        be a list of the matching formatted log messages and the
811        `records` attribute will be a list of the corresponding LogRecord
812        objects.
813
814        Example::
815
816            with self.assertLogs('foo', level='INFO') as cm:
817                logging.getLogger('foo').info('first message')
818                logging.getLogger('foo.bar').error('second message')
819            self.assertEqual(cm.output, ['INFO:foo:first message',
820                                         'ERROR:foo.bar:second message'])
821        """
822        # Lazy import to avoid importing logging if it is not needed.
823        from ._log import _AssertLogsContext
824        return _AssertLogsContext(self, logger, level, no_logs=False)
825
826    def assertNoLogs(self, logger=None, level=None):
827        """ Fail unless no log messages of level *level* or higher are emitted
828        on *logger_name* or its children.
829
830        This method must be used as a context manager.
831        """
832        from ._log import _AssertLogsContext
833        return _AssertLogsContext(self, logger, level, no_logs=True)
834
835    def _getAssertEqualityFunc(self, first, second):
836        """Get a detailed comparison function for the types of the two args.
837
838        Returns: A callable accepting (first, second, msg=None) that will
839        raise a failure exception if first != second with a useful human
840        readable error message for those types.
841        """
842        #
843        # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
844        # and vice versa.  I opted for the conservative approach in case
845        # subclasses are not intended to be compared in detail to their super
846        # class instances using a type equality func.  This means testing
847        # subtypes won't automagically use the detailed comparison.  Callers
848        # should use their type specific assertSpamEqual method to compare
849        # subclasses if the detailed comparison is desired and appropriate.
850        # See the discussion in http://bugs.python.org/issue2578.
851        #
852        if type(first) is type(second):
853            asserter = self._type_equality_funcs.get(type(first))
854            if asserter is not None:
855                if isinstance(asserter, str):
856                    asserter = getattr(self, asserter)
857                return asserter
858
859        return self._baseAssertEqual
860
861    def _baseAssertEqual(self, first, second, msg=None):
862        """The default assertEqual implementation, not type specific."""
863        if not first == second:
864            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
865            msg = self._formatMessage(msg, standardMsg)
866            raise self.failureException(msg)
867
868    def assertEqual(self, first, second, msg=None):
869        """Fail if the two objects are unequal as determined by the '=='
870           operator.
871        """
872        assertion_func = self._getAssertEqualityFunc(first, second)
873        assertion_func(first, second, msg=msg)
874
875    def assertNotEqual(self, first, second, msg=None):
876        """Fail if the two objects are equal as determined by the '!='
877           operator.
878        """
879        if not first != second:
880            msg = self._formatMessage(msg, '%s == %s' % (safe_repr(first),
881                                                          safe_repr(second)))
882            raise self.failureException(msg)
883
884    def assertAlmostEqual(self, first, second, places=None, msg=None,
885                          delta=None):
886        """Fail if the two objects are unequal as determined by their
887           difference rounded to the given number of decimal places
888           (default 7) and comparing to zero, or by comparing that the
889           difference between the two objects is more than the given
890           delta.
891
892           Note that decimal places (from zero) are usually not the same
893           as significant digits (measured from the most significant digit).
894
895           If the two objects compare equal then they will automatically
896           compare almost equal.
897        """
898        if first == second:
899            # shortcut
900            return
901        if delta is not None and places is not None:
902            raise TypeError("specify delta or places not both")
903
904        diff = abs(first - second)
905        if delta is not None:
906            if diff <= delta:
907                return
908
909            standardMsg = '%s != %s within %s delta (%s difference)' % (
910                safe_repr(first),
911                safe_repr(second),
912                safe_repr(delta),
913                safe_repr(diff))
914        else:
915            if places is None:
916                places = 7
917
918            if round(diff, places) == 0:
919                return
920
921            standardMsg = '%s != %s within %r places (%s difference)' % (
922                safe_repr(first),
923                safe_repr(second),
924                places,
925                safe_repr(diff))
926        msg = self._formatMessage(msg, standardMsg)
927        raise self.failureException(msg)
928
929    def assertNotAlmostEqual(self, first, second, places=None, msg=None,
930                             delta=None):
931        """Fail if the two objects are equal as determined by their
932           difference rounded to the given number of decimal places
933           (default 7) and comparing to zero, or by comparing that the
934           difference between the two objects is less than the given delta.
935
936           Note that decimal places (from zero) are usually not the same
937           as significant digits (measured from the most significant digit).
938
939           Objects that are equal automatically fail.
940        """
941        if delta is not None and places is not None:
942            raise TypeError("specify delta or places not both")
943        diff = abs(first - second)
944        if delta is not None:
945            if not (first == second) and diff > delta:
946                return
947            standardMsg = '%s == %s within %s delta (%s difference)' % (
948                safe_repr(first),
949                safe_repr(second),
950                safe_repr(delta),
951                safe_repr(diff))
952        else:
953            if places is None:
954                places = 7
955            if not (first == second) and round(diff, places) != 0:
956                return
957            standardMsg = '%s == %s within %r places' % (safe_repr(first),
958                                                         safe_repr(second),
959                                                         places)
960
961        msg = self._formatMessage(msg, standardMsg)
962        raise self.failureException(msg)
963
    def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
        """An equality assertion for ordered sequences (like lists and tuples).

        For the purposes of this function, a valid ordered sequence type is one
        which can be indexed, has a length, and has an equality operator.

        Args:
            seq1: The first sequence to compare.
            seq2: The second sequence to compare.
            seq_type: The expected datatype of the sequences, or None if no
                    datatype should be enforced.
            msg: Optional message to use on failure instead of a list of
                    differences.

        When seq_type is None, two sequences with equal lengths whose
        elements all compare equal are accepted even if the sequences
        themselves have different types (for example a list and a tuple).

        Raises:
            self.failureException: directly, if seq_type is given and either
                    argument is not an instance of it.  Otherwise, when the
                    sequences differ, the test is failed via self.fail() with
                    a message naming the first differing index and/or the
                    length difference, followed by an ndiff of the
                    pretty-printed sequences (dropped by _truncateMessage
                    when it is longer than self.maxDiff).
        """
        if seq_type is not None:
            seq_type_name = seq_type.__name__
            if not isinstance(seq1, seq_type):
                raise self.failureException('First sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq1)))
            if not isinstance(seq2, seq_type):
                raise self.failureException('Second sequence is not a %s: %s'
                                        % (seq_type_name, safe_repr(seq2)))
        else:
            seq_type_name = "sequence"

        # `differing` accumulates the human-readable failure description; it
        # stays None only while no problem has been detected.
        differing = None
        try:
            len1 = len(seq1)
        except (TypeError, NotImplementedError):
            # Without a length no element-wise comparison is possible; the
            # message falls back to just this note plus the ndiff below.
            differing = 'First %s has no length.    Non-sequence?' % (
                    seq_type_name)

        if differing is None:
            try:
                len2 = len(seq2)
            except (TypeError, NotImplementedError):
                differing = 'Second %s has no length.    Non-sequence?' % (
                        seq_type_name)

        if differing is None:
            if seq1 == seq2:
                return

            differing = '%ss differ: %s != %s\n' % (
                    (seq_type_name.capitalize(),) +
                    _common_shorten_repr(seq1, seq2))

            # Report only the first problem found in the overlapping prefix:
            # every branch that appends to `differing` also breaks.
            for i in range(min(len1, len2)):
                try:
                    item1 = seq1[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of first %s\n' %
                                 (i, seq_type_name))
                    break

                try:
                    item2 = seq2[i]
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('\nUnable to index element %d of second %s\n' %
                                 (i, seq_type_name))
                    break

                if item1 != item2:
                    differing += ('\nFirst differing element %d:\n%s\n%s\n' %
                                 ((i,) + _common_shorten_repr(item1, item2)))
                    break
            else:
                # No break: every overlapping element compared equal.
                if (len1 == len2 and seq_type is None and
                    type(seq1) != type(seq2)):
                    # The sequences are the same, but have differing types.
                    return

            # Describe any length mismatch.  "First extra element" means the
            # first element past the end of the shorter sequence, in
            # whichever sequence is longer.
            if len1 > len2:
                differing += ('\nFirst %s contains %d additional '
                             'elements.\n' % (seq_type_name, len1 - len2))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len2, safe_repr(seq1[len2])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of first %s\n' % (len2, seq_type_name))
            elif len1 < len2:
                differing += ('\nSecond %s contains %d additional '
                             'elements.\n' % (seq_type_name, len2 - len1))
                try:
                    differing += ('First extra element %d:\n%s\n' %
                                  (len1, safe_repr(seq2[len1])))
                except (TypeError, IndexError, NotImplementedError):
                    differing += ('Unable to index element %d '
                                  'of second %s\n' % (len1, seq_type_name))
        standardMsg = differing
        # Always attach a full line diff of the pretty-printed sequences;
        # _truncateMessage replaces it with a size note if it exceeds maxDiff.
        diffMsg = '\n' + '\n'.join(
            difflib.ndiff(pprint.pformat(seq1).splitlines(),
                          pprint.pformat(seq2).splitlines()))

        standardMsg = self._truncateMessage(standardMsg, diffMsg)
        msg = self._formatMessage(msg, standardMsg)
        self.fail(msg)
1062
1063    def _truncateMessage(self, message, diff):
1064        max_diff = self.maxDiff
1065        if max_diff is None or len(diff) <= max_diff:
1066            return message + diff
1067        return message + (DIFF_OMITTED % len(diff))
1068
1069    def assertListEqual(self, list1, list2, msg=None):
1070        """A list-specific equality assertion.
1071
1072        Args:
1073            list1: The first list to compare.
1074            list2: The second list to compare.
1075            msg: Optional message to use on failure instead of a list of
1076                    differences.
1077
1078        """
1079        self.assertSequenceEqual(list1, list2, msg, seq_type=list)
1080
1081    def assertTupleEqual(self, tuple1, tuple2, msg=None):
1082        """A tuple-specific equality assertion.
1083
1084        Args:
1085            tuple1: The first tuple to compare.
1086            tuple2: The second tuple to compare.
1087            msg: Optional message to use on failure instead of a list of
1088                    differences.
1089        """
1090        self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
1091
1092    def assertSetEqual(self, set1, set2, msg=None):
1093        """A set-specific equality assertion.
1094
1095        Args:
1096            set1: The first set to compare.
1097            set2: The second set to compare.
1098            msg: Optional message to use on failure instead of a list of
1099                    differences.
1100
1101        assertSetEqual uses ducktyping to support different types of sets, and
1102        is optimized for sets specifically (parameters must support a
1103        difference method).
1104        """
1105        try:
1106            difference1 = set1.difference(set2)
1107        except TypeError as e:
1108            self.fail('invalid type when attempting set difference: %s' % e)
1109        except AttributeError as e:
1110            self.fail('first argument does not support set difference: %s' % e)
1111
1112        try:
1113            difference2 = set2.difference(set1)
1114        except TypeError as e:
1115            self.fail('invalid type when attempting set difference: %s' % e)
1116        except AttributeError as e:
1117            self.fail('second argument does not support set difference: %s' % e)
1118
1119        if not (difference1 or difference2):
1120            return
1121
1122        lines = []
1123        if difference1:
1124            lines.append('Items in the first set but not the second:')
1125            for item in difference1:
1126                lines.append(repr(item))
1127        if difference2:
1128            lines.append('Items in the second set but not the first:')
1129            for item in difference2:
1130                lines.append(repr(item))
1131
1132        standardMsg = '\n'.join(lines)
1133        self.fail(self._formatMessage(msg, standardMsg))
1134
1135    def assertIn(self, member, container, msg=None):
1136        """Just like self.assertTrue(a in b), but with a nicer default message."""
1137        if member not in container:
1138            standardMsg = '%s not found in %s' % (safe_repr(member),
1139                                                  safe_repr(container))
1140            self.fail(self._formatMessage(msg, standardMsg))
1141
1142    def assertNotIn(self, member, container, msg=None):
1143        """Just like self.assertTrue(a not in b), but with a nicer default message."""
1144        if member in container:
1145            standardMsg = '%s unexpectedly found in %s' % (safe_repr(member),
1146                                                        safe_repr(container))
1147            self.fail(self._formatMessage(msg, standardMsg))
1148
1149    def assertIs(self, expr1, expr2, msg=None):
1150        """Just like self.assertTrue(a is b), but with a nicer default message."""
1151        if expr1 is not expr2:
1152            standardMsg = '%s is not %s' % (safe_repr(expr1),
1153                                             safe_repr(expr2))
1154            self.fail(self._formatMessage(msg, standardMsg))
1155
1156    def assertIsNot(self, expr1, expr2, msg=None):
1157        """Just like self.assertTrue(a is not b), but with a nicer default message."""
1158        if expr1 is expr2:
1159            standardMsg = 'unexpectedly identical: %s' % (safe_repr(expr1),)
1160            self.fail(self._formatMessage(msg, standardMsg))
1161
1162    def assertDictEqual(self, d1, d2, msg=None):
1163        self.assertIsInstance(d1, dict, 'First argument is not a dictionary')
1164        self.assertIsInstance(d2, dict, 'Second argument is not a dictionary')
1165
1166        if d1 != d2:
1167            standardMsg = '%s != %s' % _common_shorten_repr(d1, d2)
1168            diff = ('\n' + '\n'.join(difflib.ndiff(
1169                           pprint.pformat(d1).splitlines(),
1170                           pprint.pformat(d2).splitlines())))
1171            standardMsg = self._truncateMessage(standardMsg, diff)
1172            self.fail(self._formatMessage(msg, standardMsg))
1173
1174    def assertDictContainsSubset(self, subset, dictionary, msg=None):
1175        """Checks whether dictionary is a superset of subset."""
1176        warnings.warn('assertDictContainsSubset is deprecated',
1177                      DeprecationWarning)
1178        missing = []
1179        mismatched = []
1180        for key, value in subset.items():
1181            if key not in dictionary:
1182                missing.append(key)
1183            elif value != dictionary[key]:
1184                mismatched.append('%s, expected: %s, actual: %s' %
1185                                  (safe_repr(key), safe_repr(value),
1186                                   safe_repr(dictionary[key])))
1187
1188        if not (missing or mismatched):
1189            return
1190
1191        standardMsg = ''
1192        if missing:
1193            standardMsg = 'Missing: %s' % ','.join(safe_repr(m) for m in
1194                                                    missing)
1195        if mismatched:
1196            if standardMsg:
1197                standardMsg += '; '
1198            standardMsg += 'Mismatched values: %s' % ','.join(mismatched)
1199
1200        self.fail(self._formatMessage(msg, standardMsg))
1201
1202
1203    def assertCountEqual(self, first, second, msg=None):
1204        """Asserts that two iterables have the same elements, the same number of
1205        times, without regard to order.
1206
1207            self.assertEqual(Counter(list(first)),
1208                             Counter(list(second)))
1209
1210         Example:
1211            - [0, 1, 1] and [1, 0, 1] compare equal.
1212            - [0, 0, 1] and [0, 1] compare unequal.
1213
1214        """
1215        first_seq, second_seq = list(first), list(second)
1216        try:
1217            first = collections.Counter(first_seq)
1218            second = collections.Counter(second_seq)
1219        except TypeError:
1220            # Handle case with unhashable elements
1221            differences = _count_diff_all_purpose(first_seq, second_seq)
1222        else:
1223            if first == second:
1224                return
1225            differences = _count_diff_hashable(first_seq, second_seq)
1226
1227        if differences:
1228            standardMsg = 'Element counts were not equal:\n'
1229            lines = ['First has %d, Second has %d:  %r' % diff for diff in differences]
1230            diffMsg = '\n'.join(lines)
1231            standardMsg = self._truncateMessage(standardMsg, diffMsg)
1232            msg = self._formatMessage(msg, standardMsg)
1233            self.fail(msg)
1234
1235    def assertMultiLineEqual(self, first, second, msg=None):
1236        """Assert that two multi-line strings are equal."""
1237        self.assertIsInstance(first, str, 'First argument is not a string')
1238        self.assertIsInstance(second, str, 'Second argument is not a string')
1239
1240        if first != second:
1241            # don't use difflib if the strings are too long
1242            if (len(first) > self._diffThreshold or
1243                len(second) > self._diffThreshold):
1244                self._baseAssertEqual(first, second, msg)
1245            firstlines = first.splitlines(keepends=True)
1246            secondlines = second.splitlines(keepends=True)
1247            if len(firstlines) == 1 and first.strip('\r\n') == first:
1248                firstlines = [first + '\n']
1249                secondlines = [second + '\n']
1250            standardMsg = '%s != %s' % _common_shorten_repr(first, second)
1251            diff = '\n' + ''.join(difflib.ndiff(firstlines, secondlines))
1252            standardMsg = self._truncateMessage(standardMsg, diff)
1253            self.fail(self._formatMessage(msg, standardMsg))
1254
1255    def assertLess(self, a, b, msg=None):
1256        """Just like self.assertTrue(a < b), but with a nicer default message."""
1257        if not a < b:
1258            standardMsg = '%s not less than %s' % (safe_repr(a), safe_repr(b))
1259            self.fail(self._formatMessage(msg, standardMsg))
1260
1261    def assertLessEqual(self, a, b, msg=None):
1262        """Just like self.assertTrue(a <= b), but with a nicer default message."""
1263        if not a <= b:
1264            standardMsg = '%s not less than or equal to %s' % (safe_repr(a), safe_repr(b))
1265            self.fail(self._formatMessage(msg, standardMsg))
1266
1267    def assertGreater(self, a, b, msg=None):
1268        """Just like self.assertTrue(a > b), but with a nicer default message."""
1269        if not a > b:
1270            standardMsg = '%s not greater than %s' % (safe_repr(a), safe_repr(b))
1271            self.fail(self._formatMessage(msg, standardMsg))
1272
1273    def assertGreaterEqual(self, a, b, msg=None):
1274        """Just like self.assertTrue(a >= b), but with a nicer default message."""
1275        if not a >= b:
1276            standardMsg = '%s not greater than or equal to %s' % (safe_repr(a), safe_repr(b))
1277            self.fail(self._formatMessage(msg, standardMsg))
1278
1279    def assertIsNone(self, obj, msg=None):
1280        """Same as self.assertTrue(obj is None), with a nicer default message."""
1281        if obj is not None:
1282            standardMsg = '%s is not None' % (safe_repr(obj),)
1283            self.fail(self._formatMessage(msg, standardMsg))
1284
1285    def assertIsNotNone(self, obj, msg=None):
1286        """Included for symmetry with assertIsNone."""
1287        if obj is None:
1288            standardMsg = 'unexpectedly None'
1289            self.fail(self._formatMessage(msg, standardMsg))
1290
1291    def assertIsInstance(self, obj, cls, msg=None):
1292        """Same as self.assertTrue(isinstance(obj, cls)), with a nicer
1293        default message."""
1294        if not isinstance(obj, cls):
1295            standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
1296            self.fail(self._formatMessage(msg, standardMsg))
1297
1298    def assertNotIsInstance(self, obj, cls, msg=None):
1299        """Included for symmetry with assertIsInstance."""
1300        if isinstance(obj, cls):
1301            standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
1302            self.fail(self._formatMessage(msg, standardMsg))
1303
1304    def assertRaisesRegex(self, expected_exception, expected_regex,
1305                          *args, **kwargs):
1306        """Asserts that the message in a raised exception matches a regex.
1307
1308        Args:
1309            expected_exception: Exception class expected to be raised.
1310            expected_regex: Regex (re.Pattern object or string) expected
1311                    to be found in error message.
1312            args: Function to be called and extra positional args.
1313            kwargs: Extra kwargs.
1314            msg: Optional message used in case of failure. Can only be used
1315                    when assertRaisesRegex is used as a context manager.
1316        """
1317        context = _AssertRaisesContext(expected_exception, self, expected_regex)
1318        return context.handle('assertRaisesRegex', args, kwargs)
1319
1320    def assertWarnsRegex(self, expected_warning, expected_regex,
1321                         *args, **kwargs):
1322        """Asserts that the message in a triggered warning matches a regexp.
1323        Basic functioning is similar to assertWarns() with the addition
1324        that only warnings whose messages also match the regular expression
1325        are considered successful matches.
1326
1327        Args:
1328            expected_warning: Warning class expected to be triggered.
1329            expected_regex: Regex (re.Pattern object or string) expected
1330                    to be found in error message.
1331            args: Function to be called and extra positional args.
1332            kwargs: Extra kwargs.
1333            msg: Optional message used in case of failure. Can only be used
1334                    when assertWarnsRegex is used as a context manager.
1335        """
1336        context = _AssertWarnsContext(expected_warning, self, expected_regex)
1337        return context.handle('assertWarnsRegex', args, kwargs)
1338
1339    def assertRegex(self, text, expected_regex, msg=None):
1340        """Fail the test unless the text matches the regular expression."""
1341        if isinstance(expected_regex, (str, bytes)):
1342            assert expected_regex, "expected_regex must not be empty."
1343            expected_regex = re.compile(expected_regex)
1344        if not expected_regex.search(text):
1345            standardMsg = "Regex didn't match: %r not found in %r" % (
1346                expected_regex.pattern, text)
1347            # _formatMessage ensures the longMessage option is respected
1348            msg = self._formatMessage(msg, standardMsg)
1349            raise self.failureException(msg)
1350
1351    def assertNotRegex(self, text, unexpected_regex, msg=None):
1352        """Fail the test if the text matches the regular expression."""
1353        if isinstance(unexpected_regex, (str, bytes)):
1354            unexpected_regex = re.compile(unexpected_regex)
1355        match = unexpected_regex.search(text)
1356        if match:
1357            standardMsg = 'Regex matched: %r matches %r in %r' % (
1358                text[match.start() : match.end()],
1359                unexpected_regex.pattern,
1360                text)
1361            # _formatMessage ensures the longMessage option is respected
1362            msg = self._formatMessage(msg, standardMsg)
1363            raise self.failureException(msg)
1364
1365
1366    def _deprecate(original_func):
1367        def deprecated_func(*args, **kwargs):
1368            warnings.warn(
1369                'Please use {0} instead.'.format(original_func.__name__),
1370                DeprecationWarning, 2)
1371            return original_func(*args, **kwargs)
1372        return deprecated_func
1373
1374    # see #9424
1375    failUnlessEqual = assertEquals = _deprecate(assertEqual)
1376    failIfEqual = assertNotEquals = _deprecate(assertNotEqual)
1377    failUnlessAlmostEqual = assertAlmostEquals = _deprecate(assertAlmostEqual)
1378    failIfAlmostEqual = assertNotAlmostEquals = _deprecate(assertNotAlmostEqual)
1379    failUnless = assert_ = _deprecate(assertTrue)
1380    failUnlessRaises = _deprecate(assertRaises)
1381    failIf = _deprecate(assertFalse)
1382    assertRaisesRegexp = _deprecate(assertRaisesRegex)
1383    assertRegexpMatches = _deprecate(assertRegex)
1384    assertNotRegexpMatches = _deprecate(assertNotRegex)
1385
1386
1387
class FunctionTestCase(TestCase):
    """A test case that wraps a plain test function.

    This lets pre-existing test functions be slipped into the unittest
    framework.  Optional set-up and tidy-up callables may be supplied.  As
    with TestCase, the tidy-up ('tearDown') function will always be called
    if the set-up ('setUp') function ran successfully.
    """

    def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
        super().__init__()
        self._setUpFunc = setUp
        self._tearDownFunc = tearDown
        self._testFunc = testFunc
        self._description = description

    def setUp(self):
        """Run the user-supplied set-up callable, if one was given."""
        fixture = self._setUpFunc
        if fixture is not None:
            fixture()

    def tearDown(self):
        """Run the user-supplied tidy-up callable, if one was given."""
        fixture = self._tearDownFunc
        if fixture is not None:
            fixture()

    def runTest(self):
        # Intentionally no docstring: TestCase.__init__ records this method's
        # __doc__, so adding one would change observable state.
        self._testFunc()

    def id(self):
        """Identify the test by the wrapped function's __name__."""
        return self._testFunc.__name__

    def __eq__(self, other):
        """Equal when of this class and wrapping the same four attributes."""
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (self._setUpFunc == other._setUpFunc
                and self._tearDownFunc == other._tearDownFunc
                and self._testFunc == other._testFunc
                and self._description == other._description)

    def __hash__(self):
        return hash((type(self), self._setUpFunc, self._tearDownFunc,
                     self._testFunc, self._description))

    def __str__(self):
        class_name = strclass(self.__class__)
        return "%s (%s)" % (class_name, self._testFunc.__name__)

    def __repr__(self):
        class_name = strclass(self.__class__)
        return "<%s tec=%s>" % (class_name, self._testFunc)

    def shortDescription(self):
        """Return the explicit description, else the first docstring line.

        Returns None when neither yields a non-empty string.
        """
        if self._description is not None:
            return self._description
        doc = self._testFunc.__doc__
        if not doc:
            return None
        return doc.split("\n")[0].strip() or None
1444
1445
class _SubTest(TestCase):
    """Stand-in TestCase representing one subtest of a parent test case.

    It is only used for identification and reporting; running it directly
    raises NotImplementedError.
    """

    def __init__(self, test_case, message, params):
        super().__init__()
        self._message = message
        self.test_case = test_case
        self.params = params
        # Failures inside the subtest use the parent's exception type.
        self.failureException = test_case.failureException

    def runTest(self):
        # Intentionally no docstring: TestCase.__init__ records this method's
        # __doc__, so adding one would change observable state.
        raise NotImplementedError("subtests cannot be run directly")

    def _subDescription(self):
        """Render "[message] (k=v, ...)" from whichever parts are present.

        Falls back to '(<subtest>)' when there is neither a message nor any
        params.
        """
        pieces = []
        if self._message is not _subtest_msg_sentinel:
            pieces.append("[{}]".format(self._message))
        if self.params:
            rendered = ', '.join(
                "{}={!r}".format(key, value)
                for key, value in self.params.items())
            pieces.append("({})".format(rendered))
        return " ".join(pieces) or '(<subtest>)'

    def id(self):
        """Return the parent test's id followed by the subtest description."""
        parent_id = self.test_case.id()
        return "{} {}".format(parent_id, self._subDescription())

    def shortDescription(self):
        """Return the parent test case's shortDescription() unchanged; this
        subtest's message and params are not included.
        """
        return self.test_case.shortDescription()

    def __str__(self):
        return "{} {}".format(self.test_case, self._subDescription())
1480