1"""Loading unittests.""" 2 3import os 4import re 5import sys 6import traceback 7import types 8import functools 9import warnings 10 11from fnmatch import fnmatch, fnmatchcase 12 13from . import case, suite, util 14 15__unittest = True 16 17# what about .pyc (etc) 18# we would need to avoid loading the same tests multiple times 19# from '.py', *and* '.pyc' 20VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE) 21 22 23class _FailedTest(case.TestCase): 24 _testMethodName = None 25 26 def __init__(self, method_name, exception): 27 self._exception = exception 28 super(_FailedTest, self).__init__(method_name) 29 30 def __getattr__(self, name): 31 if name != self._testMethodName: 32 return super(_FailedTest, self).__getattr__(name) 33 def testFailure(): 34 raise self._exception 35 return testFailure 36 37 38def _make_failed_import_test(name, suiteClass): 39 message = 'Failed to import test module: %s\n%s' % ( 40 name, traceback.format_exc()) 41 return _make_failed_test(name, ImportError(message), suiteClass, message) 42 43def _make_failed_load_tests(name, exception, suiteClass): 44 message = 'Failed to call load_tests:\n%s' % (traceback.format_exc(),) 45 return _make_failed_test( 46 name, exception, suiteClass, message) 47 48def _make_failed_test(methodname, exception, suiteClass, message): 49 test = _FailedTest(methodname, exception) 50 return suiteClass((test,)), message 51 52def _make_skipped_test(methodname, exception, suiteClass): 53 @case.skip(str(exception)) 54 def testSkipped(self): 55 pass 56 attrs = {methodname: testSkipped} 57 TestClass = type("ModuleSkipped", (case.TestCase,), attrs) 58 return suiteClass((TestClass(methodname),)) 59 60def _jython_aware_splitext(path): 61 if path.lower().endswith('$py.class'): 62 return path[:-9] 63 return os.path.splitext(path)[0] 64 65 66class TestLoader(object): 67 """ 68 This class is responsible for loading tests according to various criteria 69 and returning them wrapped in a TestSuite 70 """ 71 testMethodPrefix = 'test' 72 sortTestMethodsUsing = staticmethod(util.three_way_cmp) 73 testNamePatterns = None 74 suiteClass = suite.TestSuite 75 _top_level_dir = None 76 77 def __init__(self): 78 super(TestLoader, self).__init__() 79 self.errors = [] 80 # Tracks packages which we have called into via load_tests, to 81 # avoid infinite re-entrancy. 82 self._loading_packages = set() 83 84 def loadTestsFromTestCase(self, testCaseClass): 85 """Return a suite of all test cases contained in testCaseClass""" 86 if issubclass(testCaseClass, suite.TestSuite): 87 raise TypeError("Test cases should not be derived from " 88 "TestSuite. Maybe you meant to derive from " 89 "TestCase?") 90 testCaseNames = self.getTestCaseNames(testCaseClass) 91 if not testCaseNames and hasattr(testCaseClass, 'runTest'): 92 testCaseNames = ['runTest'] 93 loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames)) 94 return loaded_suite 95 96 # XXX After Python 3.5, remove backward compatibility hacks for 97 # use_load_tests deprecation via *args and **kws. See issue 16662. 98 def loadTestsFromModule(self, module, *args, pattern=None, **kws): 99 """Return a suite of all test cases contained in the given module""" 100 # This method used to take an undocumented and unofficial 101 # use_load_tests argument. For backward compatibility, we still 102 # accept the argument (which can also be the first position) but we 103 # ignore it and issue a deprecation warning if it's present. 104 if len(args) > 0 or 'use_load_tests' in kws: 105 warnings.warn('use_load_tests is deprecated and ignored', 106 DeprecationWarning) 107 kws.pop('use_load_tests', None) 108 if len(args) > 1: 109 # Complain about the number of arguments, but don't forget the 110 # required `module` argument. 111 complaint = len(args) + 1 112 raise TypeError('loadTestsFromModule() takes 1 positional argument but {} were given'.format(complaint)) 113 if len(kws) != 0: 114 # Since the keyword arguments are unsorted (see PEP 468), just 115 # pick the alphabetically sorted first argument to complain about, 116 # if multiple were given. At least the error message will be 117 # predictable. 118 complaint = sorted(kws)[0] 119 raise TypeError("loadTestsFromModule() got an unexpected keyword argument '{}'".format(complaint)) 120 tests = [] 121 for name in dir(module): 122 obj = getattr(module, name) 123 if isinstance(obj, type) and issubclass(obj, case.TestCase): 124 tests.append(self.loadTestsFromTestCase(obj)) 125 126 load_tests = getattr(module, 'load_tests', None) 127 tests = self.suiteClass(tests) 128 if load_tests is not None: 129 try: 130 return load_tests(self, tests, pattern) 131 except Exception as e: 132 error_case, error_message = _make_failed_load_tests( 133 module.__name__, e, self.suiteClass) 134 self.errors.append(error_message) 135 return error_case 136 return tests 137 138 def loadTestsFromName(self, name, module=None): 139 """Return a suite of all test cases given a string specifier. 140 141 The name may resolve either to a module, a test case class, a 142 test method within a test case class, or a callable object which 143 returns a TestCase or TestSuite instance. 144 145 The method optionally resolves the names relative to a given module. 146 """ 147 parts = name.split('.') 148 error_case, error_message = None, None 149 if module is None: 150 parts_copy = parts[:] 151 while parts_copy: 152 try: 153 module_name = '.'.join(parts_copy) 154 module = __import__(module_name) 155 break 156 except ImportError: 157 next_attribute = parts_copy.pop() 158 # Last error so we can give it to the user if needed. 159 error_case, error_message = _make_failed_import_test( 160 next_attribute, self.suiteClass) 161 if not parts_copy: 162 # Even the top level import failed: report that error. 163 self.errors.append(error_message) 164 return error_case 165 parts = parts[1:] 166 obj = module 167 for part in parts: 168 try: 169 parent, obj = obj, getattr(obj, part) 170 except AttributeError as e: 171 # We can't traverse some part of the name. 172 if (getattr(obj, '__path__', None) is not None 173 and error_case is not None): 174 # This is a package (no __path__ per importlib docs), and we 175 # encountered an error importing something. We cannot tell 176 # the difference between package.WrongNameTestClass and 177 # package.wrong_module_name so we just report the 178 # ImportError - it is more informative. 179 self.errors.append(error_message) 180 return error_case 181 else: 182 # Otherwise, we signal that an AttributeError has occurred. 183 error_case, error_message = _make_failed_test( 184 part, e, self.suiteClass, 185 'Failed to access attribute:\n%s' % ( 186 traceback.format_exc(),)) 187 self.errors.append(error_message) 188 return error_case 189 190 if isinstance(obj, types.ModuleType): 191 return self.loadTestsFromModule(obj) 192 elif isinstance(obj, type) and issubclass(obj, case.TestCase): 193 return self.loadTestsFromTestCase(obj) 194 elif (isinstance(obj, types.FunctionType) and 195 isinstance(parent, type) and 196 issubclass(parent, case.TestCase)): 197 name = parts[-1] 198 inst = parent(name) 199 # static methods follow a different path 200 if not isinstance(getattr(inst, name), types.FunctionType): 201 return self.suiteClass([inst]) 202 elif isinstance(obj, suite.TestSuite): 203 return obj 204 if callable(obj): 205 test = obj() 206 if isinstance(test, suite.TestSuite): 207 return test 208 elif isinstance(test, case.TestCase): 209 return self.suiteClass([test]) 210 else: 211 raise TypeError("calling %s returned %s, not a test" % 212 (obj, test)) 213 else: 214 raise TypeError("don't know how to make test from: %s" % obj) 215 216 def loadTestsFromNames(self, names, module=None): 217 """Return a suite of all test cases found using the given sequence 218 of string specifiers. See 'loadTestsFromName()'. 219 """ 220 suites = [self.loadTestsFromName(name, module) for name in names] 221 return self.suiteClass(suites) 222 223 def getTestCaseNames(self, testCaseClass): 224 """Return a sorted sequence of method names found within testCaseClass 225 """ 226 def shouldIncludeMethod(attrname): 227 if not attrname.startswith(self.testMethodPrefix): 228 return False 229 testFunc = getattr(testCaseClass, attrname) 230 if not callable(testFunc): 231 return False 232 fullName = f'%s.%s.%s' % ( 233 testCaseClass.__module__, testCaseClass.__qualname__, attrname 234 ) 235 return self.testNamePatterns is None or \ 236 any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns) 237 testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass))) 238 if self.sortTestMethodsUsing: 239 testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing)) 240 return testFnNames 241 242 def discover(self, start_dir, pattern='test*.py', top_level_dir=None): 243 """Find and return all test modules from the specified start 244 directory, recursing into subdirectories to find them and return all 245 tests found within them. Only test files that match the pattern will 246 be loaded. (Using shell style pattern matching.) 247 248 All test modules must be importable from the top level of the project. 249 If the start directory is not the top level directory then the top 250 level directory must be specified separately. 251 252 If a test package name (directory with '__init__.py') matches the 253 pattern then the package will be checked for a 'load_tests' function. If 254 this exists then it will be called with (loader, tests, pattern) unless 255 the package has already had load_tests called from the same discovery 256 invocation, in which case the package module object is not scanned for 257 tests - this ensures that when a package uses discover to further 258 discover child tests that infinite recursion does not happen. 259 260 If load_tests exists then discovery does *not* recurse into the package, 261 load_tests is responsible for loading all tests in the package. 262 263 The pattern is deliberately not stored as a loader attribute so that 264 packages can continue discovery themselves. top_level_dir is stored so 265 load_tests does not need to pass this argument in to loader.discover(). 266 267 Paths are sorted before being imported to ensure reproducible execution 268 order even on filesystems with non-alphabetical ordering like ext3/4. 269 """ 270 set_implicit_top = False 271 if top_level_dir is None and self._top_level_dir is not None: 272 # make top_level_dir optional if called from load_tests in a package 273 top_level_dir = self._top_level_dir 274 elif top_level_dir is None: 275 set_implicit_top = True 276 top_level_dir = start_dir 277 278 top_level_dir = os.path.abspath(top_level_dir) 279 280 if not top_level_dir in sys.path: 281 # all test modules must be importable from the top level directory 282 # should we *unconditionally* put the start directory in first 283 # in sys.path to minimise likelihood of conflicts between installed 284 # modules and development versions? 285 sys.path.insert(0, top_level_dir) 286 self._top_level_dir = top_level_dir 287 288 is_not_importable = False 289 if os.path.isdir(os.path.abspath(start_dir)): 290 start_dir = os.path.abspath(start_dir) 291 if start_dir != top_level_dir: 292 is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py')) 293 else: 294 # support for discovery from dotted module names 295 try: 296 __import__(start_dir) 297 except ImportError: 298 is_not_importable = True 299 else: 300 the_module = sys.modules[start_dir] 301 top_part = start_dir.split('.')[0] 302 try: 303 start_dir = os.path.abspath( 304 os.path.dirname((the_module.__file__))) 305 except AttributeError: 306 if the_module.__name__ in sys.builtin_module_names: 307 # builtin module 308 raise TypeError('Can not use builtin modules ' 309 'as dotted module names') from None 310 else: 311 raise TypeError( 312 f"don't know how to discover from {the_module!r}" 313 ) from None 314 315 if set_implicit_top: 316 self._top_level_dir = self._get_directory_containing_module(top_part) 317 sys.path.remove(top_level_dir) 318 319 if is_not_importable: 320 raise ImportError('Start directory is not importable: %r' % start_dir) 321 322 tests = list(self._find_tests(start_dir, pattern)) 323 return self.suiteClass(tests) 324 325 def _get_directory_containing_module(self, module_name): 326 module = sys.modules[module_name] 327 full_path = os.path.abspath(module.__file__) 328 329 if os.path.basename(full_path).lower().startswith('__init__.py'): 330 return os.path.dirname(os.path.dirname(full_path)) 331 else: 332 # here we have been given a module rather than a package - so 333 # all we can do is search the *same* directory the module is in 334 # should an exception be raised instead 335 return os.path.dirname(full_path) 336 337 def _get_name_from_path(self, path): 338 if path == self._top_level_dir: 339 return '.' 340 path = _jython_aware_splitext(os.path.normpath(path)) 341 342 _relpath = os.path.relpath(path, self._top_level_dir) 343 assert not os.path.isabs(_relpath), "Path must be within the project" 344 assert not _relpath.startswith('..'), "Path must be within the project" 345 346 name = _relpath.replace(os.path.sep, '.') 347 return name 348 349 def _get_module_from_name(self, name): 350 __import__(name) 351 return sys.modules[name] 352 353 def _match_path(self, path, full_path, pattern): 354 # override this method to use alternative matching strategy 355 return fnmatch(path, pattern) 356 357 def _find_tests(self, start_dir, pattern): 358 """Used by discovery. Yields test suites it loads.""" 359 # Handle the __init__ in this package 360 name = self._get_name_from_path(start_dir) 361 # name is '.' when start_dir == top_level_dir (and top_level_dir is by 362 # definition not a package). 363 if name != '.' and name not in self._loading_packages: 364 # name is in self._loading_packages while we have called into 365 # loadTestsFromModule with name. 366 tests, should_recurse = self._find_test_path(start_dir, pattern) 367 if tests is not None: 368 yield tests 369 if not should_recurse: 370 # Either an error occurred, or load_tests was used by the 371 # package. 372 return 373 # Handle the contents. 374 paths = sorted(os.listdir(start_dir)) 375 for path in paths: 376 full_path = os.path.join(start_dir, path) 377 tests, should_recurse = self._find_test_path(full_path, pattern) 378 if tests is not None: 379 yield tests 380 if should_recurse: 381 # we found a package that didn't use load_tests. 382 name = self._get_name_from_path(full_path) 383 self._loading_packages.add(name) 384 try: 385 yield from self._find_tests(full_path, pattern) 386 finally: 387 self._loading_packages.discard(name) 388 389 def _find_test_path(self, full_path, pattern): 390 """Used by discovery. 391 392 Loads tests from a single file, or a directories' __init__.py when 393 passed the directory. 394 395 Returns a tuple (None_or_tests_from_file, should_recurse). 396 """ 397 basename = os.path.basename(full_path) 398 if os.path.isfile(full_path): 399 if not VALID_MODULE_NAME.match(basename): 400 # valid Python identifiers only 401 return None, False 402 if not self._match_path(basename, full_path, pattern): 403 return None, False 404 # if the test file matches, load it 405 name = self._get_name_from_path(full_path) 406 try: 407 module = self._get_module_from_name(name) 408 except case.SkipTest as e: 409 return _make_skipped_test(name, e, self.suiteClass), False 410 except: 411 error_case, error_message = \ 412 _make_failed_import_test(name, self.suiteClass) 413 self.errors.append(error_message) 414 return error_case, False 415 else: 416 mod_file = os.path.abspath( 417 getattr(module, '__file__', full_path)) 418 realpath = _jython_aware_splitext( 419 os.path.realpath(mod_file)) 420 fullpath_noext = _jython_aware_splitext( 421 os.path.realpath(full_path)) 422 if realpath.lower() != fullpath_noext.lower(): 423 module_dir = os.path.dirname(realpath) 424 mod_name = _jython_aware_splitext( 425 os.path.basename(full_path)) 426 expected_dir = os.path.dirname(full_path) 427 msg = ("%r module incorrectly imported from %r. Expected " 428 "%r. Is this module globally installed?") 429 raise ImportError( 430 msg % (mod_name, module_dir, expected_dir)) 431 return self.loadTestsFromModule(module, pattern=pattern), False 432 elif os.path.isdir(full_path): 433 if not os.path.isfile(os.path.join(full_path, '__init__.py')): 434 return None, False 435 436 load_tests = None 437 tests = None 438 name = self._get_name_from_path(full_path) 439 try: 440 package = self._get_module_from_name(name) 441 except case.SkipTest as e: 442 return _make_skipped_test(name, e, self.suiteClass), False 443 except: 444 error_case, error_message = \ 445 _make_failed_import_test(name, self.suiteClass) 446 self.errors.append(error_message) 447 return error_case, False 448 else: 449 load_tests = getattr(package, 'load_tests', None) 450 # Mark this package as being in load_tests (possibly ;)) 451 self._loading_packages.add(name) 452 try: 453 tests = self.loadTestsFromModule(package, pattern=pattern) 454 if load_tests is not None: 455 # loadTestsFromModule(package) has loaded tests for us. 456 return tests, False 457 return tests, True 458 finally: 459 self._loading_packages.discard(name) 460 else: 461 return None, False 462 463 464defaultTestLoader = TestLoader() 465 466 467# These functions are considered obsolete for long time. 468# They will be removed in Python 3.13. 469 470def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None): 471 loader = TestLoader() 472 loader.sortTestMethodsUsing = sortUsing 473 loader.testMethodPrefix = prefix 474 loader.testNamePatterns = testNamePatterns 475 if suiteClass: 476 loader.suiteClass = suiteClass 477 return loader 478 479def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None): 480 import warnings 481 warnings.warn( 482 "unittest.getTestCaseNames() is deprecated and will be removed in Python 3.13. " 483 "Please use unittest.TestLoader.getTestCaseNames() instead.", 484 DeprecationWarning, stacklevel=2 485 ) 486 return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass) 487 488def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp, 489 suiteClass=suite.TestSuite): 490 import warnings 491 warnings.warn( 492 "unittest.makeSuite() is deprecated and will be removed in Python 3.13. " 493 "Please use unittest.TestLoader.loadTestsFromTestCase() instead.", 494 DeprecationWarning, stacklevel=2 495 ) 496 return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase( 497 testCaseClass) 498 499def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp, 500 suiteClass=suite.TestSuite): 501 import warnings 502 warnings.warn( 503 "unittest.findTestCases() is deprecated and will be removed in Python 3.13. " 504 "Please use unittest.TestLoader.loadTestsFromModule() instead.", 505 DeprecationWarning, stacklevel=2 506 ) 507 return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(\ 508 module) 509