1from __future__ import annotations 2 3import glob 4import os 5import sys 6from pathlib import Path 7 8 9CPP_TEST_PREFIX = "cpp" 10CPP_TEST_PATH = "build/bin" 11CPP_TESTS_DIR = os.path.abspath(os.getenv("CPP_TESTS_DIR", default=CPP_TEST_PATH)) 12REPO_ROOT = Path(__file__).resolve().parent.parent.parent 13 14 15def parse_test_module(test: str) -> str: 16 return test.split(".")[0] 17 18 19def discover_tests( 20 base_dir: Path = REPO_ROOT / "test", 21 cpp_tests_dir: str | Path | None = None, 22 blocklisted_patterns: list[str] | None = None, 23 blocklisted_tests: list[str] | None = None, 24 extra_tests: list[str] | None = None, 25) -> list[str]: 26 """ 27 Searches for all python files starting with test_ excluding one specified by patterns. 28 If cpp_tests_dir is provided, also scan for all C++ tests under that directory. They 29 are usually found in build/bin 30 """ 31 32 def skip_test_p(name: str) -> bool: 33 rc = False 34 if blocklisted_patterns is not None: 35 rc |= any(name.startswith(pattern) for pattern in blocklisted_patterns) 36 if blocklisted_tests is not None: 37 rc |= name in blocklisted_tests 38 return rc 39 40 # This supports symlinks, so we can link domain library tests to PyTorch test directory 41 all_py_files = [ 42 Path(p) for p in glob.glob(f"{base_dir}/**/test_*.py", recursive=True) 43 ] 44 45 cpp_tests_dir = ( 46 f"{base_dir.parent}/{CPP_TEST_PATH}" if cpp_tests_dir is None else cpp_tests_dir 47 ) 48 # CPP test files are located under pytorch/build/bin. Unlike Python test, C++ tests 49 # are just binaries and could have any name, i.e. basic or atest 50 all_cpp_files = [ 51 Path(p) for p in glob.glob(f"{cpp_tests_dir}/**/*", recursive=True) 52 ] 53 54 rc = [str(fname.relative_to(base_dir))[:-3] for fname in all_py_files] 55 # Add the cpp prefix for C++ tests so that we can tell them apart 56 rc.extend( 57 [ 58 parse_test_module(f"{CPP_TEST_PREFIX}/{fname.relative_to(cpp_tests_dir)}") 59 for fname in all_cpp_files 60 ] 61 ) 62 63 # Invert slashes on Windows 64 if sys.platform == "win32": 65 rc = [name.replace("\\", "/") for name in rc] 66 rc = [test for test in rc if not skip_test_p(test)] 67 if extra_tests is not None: 68 rc += extra_tests 69 return sorted(rc) 70 71 72TESTS = discover_tests( 73 cpp_tests_dir=CPP_TESTS_DIR, 74 blocklisted_patterns=[ 75 "ao", 76 "bottleneck_test", 77 "custom_backend", 78 "custom_operator", 79 "fx", # executed by test_fx.py 80 "jit", # executed by test_jit.py 81 "mobile", 82 "onnx_caffe2", 83 "package", # executed by test_package.py 84 "quantization", # executed by test_quantization.py 85 "autograd", # executed by test_autograd.py 86 ], 87 blocklisted_tests=[ 88 "test_bundled_images", 89 "test_cpp_extensions_aot", 90 "test_determination", 91 "test_jit_fuser", 92 "test_jit_simple", 93 "test_jit_string", 94 "test_kernel_launch_checks", 95 "test_nnapi", 96 "test_static_runtime", 97 "test_throughput_benchmark", 98 "distributed/bin/test_script", 99 "distributed/elastic/multiprocessing/bin/test_script", 100 "distributed/launcher/bin/test_script", 101 "distributed/launcher/bin/test_script_init_method", 102 "distributed/launcher/bin/test_script_is_torchelastic_launched", 103 "distributed/launcher/bin/test_script_local_rank", 104 "distributed/test_c10d_spawn", 105 "distributions/test_transforms", 106 "distributions/test_utils", 107 "test/inductor/test_aot_inductor_utils", 108 "onnx/test_pytorch_onnx_onnxruntime_cuda", 109 "onnx/test_models", 110 # These are not C++ tests 111 f"{CPP_TEST_PREFIX}/CMakeFiles", 112 f"{CPP_TEST_PREFIX}/CTestTestfile.cmake", 113 f"{CPP_TEST_PREFIX}/Makefile", 114 f"{CPP_TEST_PREFIX}/cmake_install.cmake", 115 f"{CPP_TEST_PREFIX}/c10_intrusive_ptr_benchmark", 116 f"{CPP_TEST_PREFIX}/example_allreduce", 117 f"{CPP_TEST_PREFIX}/parallel_benchmark", 118 f"{CPP_TEST_PREFIX}/protoc", 119 f"{CPP_TEST_PREFIX}/protoc-3.13.0.0", 120 f"{CPP_TEST_PREFIX}/torch_shm_manager", 121 f"{CPP_TEST_PREFIX}/tutorial_tensorexpr", 122 ], 123 extra_tests=[ 124 "test_cpp_extensions_aot_ninja", 125 "test_cpp_extensions_aot_no_ninja", 126 "distributed/elastic/timer/api_test", 127 "distributed/elastic/timer/local_timer_example", 128 "distributed/elastic/timer/local_timer_test", 129 "distributed/elastic/events/lib_test", 130 "distributed/elastic/metrics/api_test", 131 "distributed/elastic/utils/logging_test", 132 "distributed/elastic/utils/util_test", 133 "distributed/elastic/utils/distributed_test", 134 "distributed/elastic/multiprocessing/api_test", 135 "doctests", 136 "test_autoload_enable", 137 "test_autoload_disable", 138 ], 139) 140 141 142if __name__ == "__main__": 143 print(TESTS) 144