1"""Loading unittests."""
2
3import os
4import re
5import sys
6import traceback
7import types
8import functools
9import warnings
10
11from fnmatch import fnmatch, fnmatchcase
12
13from . import case, suite, util
14
# NOTE(review): presumably the marker other unittest modules look for to
# recognise frames that belong to the framework itself -- confirm.
__unittest = True

# Discovery only loads files whose name looks like a Python identifier
# followed by '.py' (matched case-insensitively).
# what about .pyc (etc)
# we would need to avoid loading the same tests multiple times
# from '.py', *and* '.pyc'
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
21
22
class _FailedTest(case.TestCase):
    """Synthetic test case whose single test method re-raises an exception.

    The loader uses it to report problems found while loading (failed
    imports, failing load_tests hooks, missing attributes) as ordinary
    test failures instead of aborting the load.
    """
    # Class-level default so that __getattr__ can read the attribute
    # before TestCase.__init__ has stored it on the instance (otherwise
    # the lookup would re-enter __getattr__).
    _testMethodName = None

    def __init__(self, method_name, exception):
        """Remember *exception* and register *method_name* as the test."""
        # Must be set first: the base initialiser looks the test method up,
        # which goes through __getattr__ below.
        self._exception = exception
        super().__init__(method_name)

    def __getattr__(self, name):
        """Synthesise the test method; any other missing name is an error.

        Only called for attributes that normal lookup did not find.
        """
        if name != self._testMethodName:
            # Delegate to the next class in the MRO; when none of them
            # defines __getattr__ this raises AttributeError, which is the
            # required outcome for an unknown attribute.
            return super().__getattr__(name)
        def testFailure():
            raise self._exception
        return testFailure
36
37
def _make_failed_import_test(name, suiteClass):
    """Wrap the exception currently being handled as a failing import test.

    Must be called while an exception is being handled: the formatted
    traceback of that exception becomes part of the message.

    Returns a ``(suite, message)`` pair, where the suite holds one test
    that raises ImportError(message) when run.
    """
    details = traceback.format_exc()
    message = 'Failed to import test module: %s\n%s' % (name, details)
    failure = ImportError(message)
    return _make_failed_test(name, failure, suiteClass, message)
42
def _make_failed_load_tests(name, exception, suiteClass):
    """Wrap a failure raised by a load_tests hook as a failing test.

    Must be called while *exception* is being handled so that its
    formatted traceback can be included in the message.

    Returns a ``(suite, message)`` pair, where the suite holds one test
    that re-raises *exception* when run.
    """
    details = traceback.format_exc()
    message = 'Failed to call load_tests:\n%s' % (details,)
    return _make_failed_test(name, exception, suiteClass, message)
47
def _make_failed_test(methodname, exception, suiteClass, message):
    """Build a one-test suite that raises *exception* when it is run.

    Returns ``(suite, message)``; *message* is handed back unchanged so
    that callers can record it.
    """
    failing_case = _FailedTest(methodname, exception)
    wrapped = suiteClass((failing_case,))
    return wrapped, message
51
def _make_skipped_test(methodname, exception, suiteClass):
    """Build a one-test suite whose only test is skipped.

    The skip reason is ``str(exception)``.  The test lives on a freshly
    created TestCase subclass called "ModuleSkipped" and is stored under
    the attribute name *methodname*.
    """
    def testSkipped(self):
        pass
    # Apply the skip decorator by hand instead of using the @ syntax.
    skipped_method = case.skip(str(exception))(testSkipped)
    namespace = {methodname: skipped_method}
    skipped_class = type("ModuleSkipped", (case.TestCase,), namespace)
    instance = skipped_class(methodname)
    return suiteClass((instance,))
59
def _jython_aware_splitext(path):
    """Return *path* without its extension.

    A trailing '$py.class' (compared case-insensitively) is removed as a
    whole; any other path is handled by os.path.splitext().
    """
    jython_suffix = '$py.class'
    if not path.lower().endswith(jython_suffix):
        root, _extension = os.path.splitext(path)
        return root
    return path[:-len(jython_suffix)]
64
65
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite

    Tests can be loaded from a TestCase class, a module, a dotted name or
    by walking a directory tree (see discover()).  Problems that are turned
    into failing tests rather than raised are also recorded, as formatted
    messages, in the ``errors`` list of the instance.
    """
    # Only attributes whose name starts with this prefix are test methods.
    testMethodPrefix = 'test'
    # Three-way comparison used to sort the method names; a false value
    # disables the sorting step in getTestCaseNames().
    sortTestMethodsUsing = staticmethod(util.three_way_cmp)
    # Optional iterable of fnmatchcase() patterns matched against
    # 'module.qualname.method'; None disables the filtering.
    testNamePatterns = None
    # Callable used to build every suite this loader returns.
    suiteClass = suite.TestSuite
    # Absolute top-level directory remembered by discover().
    _top_level_dir = None

    def __init__(self):
        super().__init__()
        # Messages for the errors that were converted into failing tests.
        self.errors = []
        # Tracks packages which we have called into via load_tests, to
        # avoid infinite re-entrancy.
        self._loading_packages = set()

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all test cases contained in testCaseClass

        When the class has no matching test method but defines runTest,
        a single 'runTest' test is loaded instead.  The base classes
        TestCase and FunctionTestCase themselves never yield any test.

        Raises TypeError if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from "
                            "TestSuite. Maybe you meant to derive from "
                            "TestCase?")
        if testCaseClass in (case.TestCase, case.FunctionTestCase):
            # Framework base classes are not tests.  FunctionTestCase in
            # particular defines runTest and would otherwise be
            # instantiated as FunctionTestCase('runTest').
            testCaseNames = []
        else:
            testCaseNames = self.getTestCaseNames(testCaseClass)
            if not testCaseNames and hasattr(testCaseClass, 'runTest'):
                testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    # XXX After Python 3.5, remove backward compatibility hacks for
    # use_load_tests deprecation via *args and **kws.  See issue 16662.
    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
        """Return a suite of all test cases contained in the given module

        Every TestCase subclass found among the module attributes is
        loaded, except the base classes TestCase and FunctionTestCase
        (which a module typically only has because it imported them).

        If the module defines load_tests, it is called with
        (loader, tests, pattern) and its result is returned.  An exception
        raised by load_tests is recorded in self.errors and a suite with a
        single failing test is returned instead.

        Raises TypeError for unexpected positional or keyword arguments.
        """
        # This method used to take an undocumented and unofficial
        # use_load_tests argument.  For backward compatibility, we still
        # accept the argument (which can also be the first position) but we
        # ignore it and issue a deprecation warning if it's present.
        if len(args) > 0 or 'use_load_tests' in kws:
            warnings.warn('use_load_tests is deprecated and ignored',
                          DeprecationWarning)
            kws.pop('use_load_tests', None)
        if len(args) > 1:
            # Complain about the number of arguments, but don't forget the
            # required `module` argument.
            complaint = len(args) + 1
            raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint))
        if len(kws) != 0:
            # Since the keyword arguments are unsorted (see PEP 468), just
            # pick the alphabetically sorted first argument to complain about,
            # if multiple were given.  At least the error message will be
            # predictable.
            complaint = sorted(kws)[0]
            raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint))
        base_classes = (case.TestCase, case.FunctionTestCase)
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if (isinstance(obj, type)
                    and issubclass(obj, case.TestCase)
                    and obj not in base_classes):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if load_tests is not None:
            try:
                return load_tests(self, tests, pattern)
            except Exception as e:
                error_case, error_message = _make_failed_load_tests(
                    module.__name__, e, self.suiteClass)
                self.errors.append(error_message)
                return error_case
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all test cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Import and attribute errors met while resolving the name are
        recorded in self.errors and returned as a suite holding a single
        failing test.  Raises TypeError when the resolved object cannot be
        turned into a test.
        """
        parts = name.split('.')
        error_case, error_message = None, None
        if module is None:
            # Import the longest importable prefix of the dotted name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module_name = '.'.join(parts_copy)
                    module = __import__(module_name)
                    break
                except ImportError:
                    next_attribute = parts_copy.pop()
                    # Last error so we can give it to the user if needed.
                    error_case, error_message = _make_failed_import_test(
                        next_attribute, self.suiteClass)
                    if not parts_copy:
                        # Even the top level import failed: report that error.
                        self.errors.append(error_message)
                        return error_case
            # __import__ returned the top-level module, so the remaining
            # parts are walked as attributes starting from it.
            parts = parts[1:]
        obj = module
        # Bound up front so the checks below never see an unbound name when
        # there were no parts to traverse.
        parent = None
        for part in parts:
            try:
                parent, obj = obj, getattr(obj, part)
            except AttributeError as e:
                # We can't traverse some part of the name.
                if (getattr(obj, '__path__', None) is not None
                    and error_case is not None):
                    # obj has a __path__, i.e. it is a package, and we
                    # encountered an error importing something. We cannot tell
                    # the difference between package.WrongNameTestClass and
                    # package.wrong_module_name so we just report the
                    # ImportError - it is more informative.
                    self.errors.append(error_message)
                    return error_case
                else:
                    # Otherwise, we signal that an AttributeError has occurred.
                    error_case, error_message = _make_failed_test(
                        part, e, self.suiteClass,
                        'Failed to access attribute:\n%s' % (
                            traceback.format_exc(),))
                    self.errors.append(error_message)
                    return error_case

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.FunctionType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            name = parts[-1]
            inst = parent(name)
            # static methods follow a different path
            if not isinstance(getattr(inst, name), types.FunctionType):
                return self.suiteClass([inst])
        elif isinstance(obj, suite.TestSuite):
            return obj
        if callable(obj):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all test cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name is kept when it starts with testMethodPrefix, refers to a
        callable attribute and, if testNamePatterns is set, its full name
        'module.qualname.method' matches at least one of the patterns.
        The names are sorted with sortTestMethodsUsing unless it is false.
        """
        def shouldIncludeMethod(attrname):
            if not attrname.startswith(self.testMethodPrefix):
                return False
            testFunc = getattr(testCaseClass, attrname)
            if not callable(testFunc):
                return False
            if self.testNamePatterns is None:
                # No filtering requested: skip building the full name.
                return True
            fullName = (f'{testCaseClass.__module__}.'
                        f'{testCaseClass.__qualname__}.{attrname}')
            return any(fnmatchcase(fullName, pattern)
                       for pattern in self.testNamePatterns)
        testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them and return all
        tests found within them. Only test files that match the pattern will
        be loaded. (Using shell style pattern matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with (loader, tests, pattern) unless
        the package has already had load_tests called from the same discovery
        invocation, in which case the package module object is not scanned for
        tests - this ensures that when a package uses discover to further
        discover child tests that infinite recursion does not happen.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Paths are sorted before being imported to ensure reproducible execution
        order even on filesystems with non-alphabetical ordering like ext3/4.

        start_dir may also be a dotted module name instead of a directory.
        Raises ImportError if the start directory is not importable and
        TypeError if the dotted name refers to a module whose __file__ is
        missing or None (for example a builtin module).
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                module_file = getattr(the_module, '__file__', None)
                if module_file is None:
                    # Without a file there is no directory to scan.  A
                    # __file__ of None is handled like a missing one so
                    # that os.path is never handed None.
                    if the_module.__name__ in sys.builtin_module_names:
                        # builtin module
                        raise TypeError('Can not use builtin modules '
                                        'as dotted module names')
                    raise TypeError(
                        f"don't know how to discover from {the_module!r}")
                start_dir = os.path.abspath(os.path.dirname(module_file))

                if set_implicit_top:
                    # The dotted name was only provisionally used as the top
                    # level: replace it with the directory that contains the
                    # top-level module and undo the sys.path entry.
                    self._top_level_dir = self._get_directory_containing_module(top_part)
                    sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which module_name can be imported.

        The module is looked up in sys.modules.  For a package (a file
        whose name starts with '__init__.py') this is the parent of the
        package directory, otherwise the directory of the module file.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # should an exception be raised instead
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file or directory path into a dotted module name.

        Returns '.' for the top-level directory itself.  The path has to
        be located inside the top-level directory.
        """
        if path == self._top_level_dir:
            return '.'
        path = _jython_aware_splitext(os.path.normpath(path))

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the module called name and return it from sys.modules."""
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        """Return whether the file name path matches pattern."""
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        # Handle the __init__ in this package
        name = self._get_name_from_path(start_dir)
        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
        # definition not a package).
        if name != '.' and name not in self._loading_packages:
            # name is in self._loading_packages while we have called into
            # loadTestsFromModule with name.
            tests, should_recurse = self._find_test_path(start_dir, pattern)
            if tests is not None:
                yield tests
            if not should_recurse:
                # Either an error occurred, or load_tests was used by the
                # package.
                return
        # Handle the contents.
        paths = sorted(os.listdir(start_dir))
        for path in paths:
            full_path = os.path.join(start_dir, path)
            tests, should_recurse = self._find_test_path(full_path, pattern)
            if tests is not None:
                yield tests
            if should_recurse:
                # we found a package that didn't use load_tests.
                name = self._get_name_from_path(full_path)
                self._loading_packages.add(name)
                try:
                    yield from self._find_tests(full_path, pattern)
                finally:
                    self._loading_packages.discard(name)

    def _find_test_path(self, full_path, pattern):
        """Used by discovery.

        Loads tests from a single file, or a directories' __init__.py when
        passed the directory.

        Returns a tuple (None_or_tests_from_file, should_recurse).

        Raises ImportError when the module imported for a file turns out
        to come from a different location than full_path.
        """
        basename = os.path.basename(full_path)
        if os.path.isfile(full_path):
            if not VALID_MODULE_NAME.match(basename):
                # valid Python identifiers only
                return None, False
            if not self._match_path(basename, full_path, pattern):
                return None, False
            # if the test file matches, load it
            name = self._get_name_from_path(full_path)
            try:
                module = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Bare except: whatever the import raised is reported as a
                # failing test and recorded in self.errors.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                # Check that the module really was imported from full_path
                # (compared without extension and case-insensitively).
                mod_file = os.path.abspath(
                    getattr(module, '__file__', full_path))
                realpath = _jython_aware_splitext(
                    os.path.realpath(mod_file))
                fullpath_noext = _jython_aware_splitext(
                    os.path.realpath(full_path))
                if realpath.lower() != fullpath_noext.lower():
                    module_dir = os.path.dirname(realpath)
                    mod_name = _jython_aware_splitext(
                        os.path.basename(full_path))
                    expected_dir = os.path.dirname(full_path)
                    msg = ("%r module incorrectly imported from %r. Expected "
                           "%r. Is this module globally installed?")
                    raise ImportError(
                        msg % (mod_name, module_dir, expected_dir))
                return self.loadTestsFromModule(module, pattern=pattern), False
        elif os.path.isdir(full_path):
            if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                return None, False

            name = self._get_name_from_path(full_path)
            try:
                package = self._get_module_from_name(name)
            except case.SkipTest as e:
                return _make_skipped_test(name, e, self.suiteClass), False
            except:
                # Bare except: see the file branch above.
                error_case, error_message = \
                    _make_failed_import_test(name, self.suiteClass)
                self.errors.append(error_message)
                return error_case, False
            else:
                load_tests = getattr(package, 'load_tests', None)
                # Mark this package as being in load_tests (possibly ;))
                self._loading_packages.add(name)
                try:
                    tests = self.loadTestsFromModule(package, pattern=pattern)
                    if load_tests is not None:
                        # loadTestsFromModule(package) has loaded tests for us.
                        return tests, False
                    return tests, True
                finally:
                    self._loading_packages.discard(name)
        else:
            return None, False
463
# Module-level TestLoader instance created once at import time.
defaultTestLoader = TestLoader()
465
466
# These functions have been considered obsolete for a long time.
# They will be removed in Python 3.13.
469
def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None):
    """Return a new TestLoader configured from the given settings.

    suiteClass only replaces the loader's default when it is truthy.
    """
    configured = TestLoader()
    if suiteClass:
        configured.suiteClass = suiteClass
    configured.testNamePatterns = testNamePatterns
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    return configured
478
def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None):
    """Deprecated helper: return the test method names of testCaseClass.

    Issues a DeprecationWarning attributed to the caller, then delegates
    to the getTestCaseNames() method of a loader built by _makeLoader().
    """
    notice = (
        "unittest.getTestCaseNames() is deprecated and will be removed in Python 3.13. "
        "Please use unittest.TestLoader.getTestCaseNames() instead."
    )
    warnings.warn(notice, category=DeprecationWarning, stacklevel=2)
    loader = _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns)
    return loader.getTestCaseNames(testCaseClass)
487
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
              suiteClass=suite.TestSuite):
    """Deprecated helper: return a suite with the tests of testCaseClass.

    Issues a DeprecationWarning attributed to the caller, then delegates
    to the loadTestsFromTestCase() method of a loader built by
    _makeLoader().
    """
    notice = (
        "unittest.makeSuite() is deprecated and will be removed in Python 3.13. "
        "Please use unittest.TestLoader.loadTestsFromTestCase() instead."
    )
    warnings.warn(notice, category=DeprecationWarning, stacklevel=2)
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
498
def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
                  suiteClass=suite.TestSuite):
    """Deprecated helper: return a suite with the tests found in module.

    Issues a DeprecationWarning attributed to the caller, then delegates
    to the loadTestsFromModule() method of a loader built by
    _makeLoader().
    """
    notice = (
        "unittest.findTestCases() is deprecated and will be removed in Python 3.13. "
        "Please use unittest.TestLoader.loadTestsFromModule() instead."
    )
    warnings.warn(notice, category=DeprecationWarning, stacklevel=2)
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)
509