from __future__ import annotations

import math
import os
import subprocess
from pathlib import Path
from typing import Callable, Sequence

from tools.stats.import_test_stats import get_disabled_tests
from tools.testing.test_run import ShardedTest, TestRun


REPO_ROOT = Path(__file__).resolve().parent.parent.parent

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
USE_3_PROCS = "sm86" in BUILD_ENVIRONMENT or "cuda" not in BUILD_ENVIRONMENT

# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
# so that sharding is consistent. NUM_PROCS is the actual number of procs used
# to run tests. If they are not equal, the only consequence should be unequal
# shards.
IS_ROCM = os.path.exists("/opt/rocm")
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if USE_3_PROCS else 2
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10  # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query the number of GPUs available.
# torch.version.hip was not available to check if this was a ROCm self-hosted
# runner, so we check for the /opt/rocm directory instead.
if IS_ROCM and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in the GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        count = 0
        for line in lines:
            if " gfx" in line:
                count += 1
        assert count > 0  # there must be at least 1 GPU
        # Limit to 8 GPUs (procs)
        NUM_PROCS = min(count, 8)
    except subprocess.CalledProcessError:
        # The safe default for ROCm GHA runners is to run tests serially.
        NUM_PROCS = 1


class ShardJob:
    def __init__(self) -> None:
        self.serial: list[ShardedTest] = []
        self.parallel: list[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate the total run time of this shard: parallel tests are
        greedily packed onto NUM_PROCS_FOR_SHARDING_CALC workers and serial
        tests run one after another on top of that. A test with no known time
        falls back to get_time()'s default."""
        procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)]
        for test in self.parallel:
            min_index = procs.index(min(procs))
            procs[min_index] += test.get_time()
        time = max(procs) + sum(test.get_time() for test in self.serial)
        return time

    def convert_to_tuple(self) -> tuple[float, list[ShardedTest]]:
        return (self.get_total_time(), self.serial + self.parallel)


def get_with_pytest_shard(
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
) -> list[ShardedTest]:
    sharded_tests: list[ShardedTest] = []

    for test in tests:
        duration = get_duration(test, test_file_times, test_class_times or {})

        if duration and duration > THRESHOLD:
            # Split long-running files into multiple pytest shards
            num_shards = math.ceil(duration / THRESHOLD)
            for i in range(num_shards):
                sharded_tests.append(
                    ShardedTest(test, i + 1, num_shards, duration / num_shards)
                )
        else:
            sharded_tests.append(ShardedTest(test, 1, 1, duration))

    return sharded_tests


def get_duration(
    test: TestRun,
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]],
) -> float | None:
    """Calculate the time for a TestRun based on the given test_file_times and
    test_class_times.
    Returns None if the time is unknown."""
    file_duration = test_file_times.get(test.test_file, None)
    if test.is_full_file():
        return file_duration

    def get_duration_for_classes(
        test_file: str, test_classes: frozenset[str]
    ) -> float | None:
        duration: float = 0
        for test_class in test_classes:
            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
            if class_duration is None:
                return None
            duration += class_duration
        return duration

    included = test.included()
    excluded = test.excluded()
    included_classes_duration = get_duration_for_classes(test.test_file, included)
    excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)

    if included_classes_duration is None or excluded_classes_duration is None:
        # Didn't get the time for all classes, so time is unknown
        return None

    if included:
        return included_classes_duration
    assert (
        excluded
    ), f"TestRun {test} is not full file but doesn't have included or excluded classes"
    if file_duration is None:
        return None
    return file_duration - excluded_classes_duration


def shard(
    sharded_jobs: list[ShardJob],
    pytest_sharded_tests: Sequence[ShardedTest],
    estimated_time_limit: float | None = None,
    serial: bool = False,
) -> None:
    # Modifies sharded_jobs in place
    if len(sharded_jobs) == 0:
        assert (
            len(pytest_sharded_tests) == 0
        ), "No shards provided but there are tests to shard"
        return

    round_robin_index = 0

    def _get_min_sharded_job(
        sharded_jobs: list[ShardJob], test: ShardedTest
    ) -> ShardJob:
        if test.time is None:
            nonlocal round_robin_index
            job = sharded_jobs[round_robin_index % len(sharded_jobs)]
            round_robin_index += 1
            return job
        return min(sharded_jobs, key=lambda j: j.get_total_time())

    def _shard_serial(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        assert estimated_time_limit is not None, "Estimated time limit must be provided"
        new_sharded_jobs = sharded_jobs
        for test in tests:
            if (
                len(sharded_jobs) > 1
                and sharded_jobs[-1].get_total_time() > estimated_time_limit
            ):
                new_sharded_jobs = sharded_jobs[:-1]
            min_sharded_job = _get_min_sharded_job(new_sharded_jobs, test)
            min_sharded_job.serial.append(test)

    def _shard_parallel(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        for test in tests:
            min_sharded_job = _get_min_sharded_job(sharded_jobs, test)
            min_sharded_job.parallel.append(test)

    if serial:
        _shard_serial(pytest_sharded_tests, sharded_jobs)
    else:
        _shard_parallel(pytest_sharded_tests, sharded_jobs)

    return


def calculate_shards(
    num_shards: int,
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
    must_serial: Callable[[str], bool] | None = None,
    sort_by_time: bool = True,
) -> list[tuple[float, list[ShardedTest]]]:
    must_serial = must_serial or (lambda x: True)
    test_class_times = test_class_times or {}

    # Divide tests into pytest shards
    if sort_by_time:
        known_tests = [
            x
            for x in tests
            if get_duration(x, test_file_times, test_class_times) is not None
        ]
        unknown_tests = [x for x in tests if x not in known_tests]

        pytest_sharded_tests = sorted(
            get_with_pytest_shard(known_tests, test_file_times, test_class_times),
            key=lambda j: j.get_time(),
            reverse=True,
        ) + get_with_pytest_shard(unknown_tests, test_file_times, test_class_times)
    else:
        pytest_sharded_tests = get_with_pytest_shard(
            tests, test_file_times, test_class_times
        )
    del tests

    serial_tests = [test for test in pytest_sharded_tests if must_serial(test.name)]
    parallel_tests = [test for test in pytest_sharded_tests if test not in serial_tests]

    serial_time = sum(test.get_time() for test in serial_tests)
    parallel_time = sum(test.get_time() for test in parallel_tests)
    total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC
    estimated_time_per_shard = total_time / num_shards

    # Separate serial tests from parallel tests as much as possible to maximize
    # parallelism by putting all the serial tests on the first num_serial_shards
    # shards. The estimated_time_limit is the estimated time it should take for
    # the least filled serial shard. Ex: if we have 8 min of serial tests, 20 min
    # of parallel tests, 6 shards, and 2 procs per machine, we would expect each
    # machine to take 3 min and should aim for 3 serial shards, with shards 1
    # and 2 taking 3 min and shard 3 taking 2 min. The estimated time limit
    # would be 2 min. This ensures that the first few shards contain as many
    # serial tests as possible and as few parallel tests as possible. The least
    # filled/last (in the example, the 3rd) shard may contain a lot of both
    # serial and parallel tests.
    estimated_time_limit = 0.0
    if estimated_time_per_shard != 0:
        estimated_time_limit = serial_time % estimated_time_per_shard
    if estimated_time_limit <= 0.01:
        estimated_time_limit = estimated_time_per_shard
    if total_time == 0:
        num_serial_shards = num_shards
    else:
        num_serial_shards = max(math.ceil(serial_time / total_time * num_shards), 1)

    sharded_jobs = [ShardJob() for _ in range(num_shards)]
    shard(
        sharded_jobs=sharded_jobs[:num_serial_shards],
        pytest_sharded_tests=serial_tests,
        estimated_time_limit=estimated_time_limit,
        serial=True,
    )
    shard(
        sharded_jobs=sharded_jobs,
        pytest_sharded_tests=parallel_tests,
        serial=False,
    )
    return [job.convert_to_tuple() for job in sharded_jobs]


def get_test_case_configs(dirpath: str) -> None:
    get_disabled_tests(dirpath=dirpath)
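

if __name__ == "__main__":
    # Minimal usage sketch of calculate_shards, kept behind the __main__ guard
    # so it never runs on import. The test names and timings below are made up
    # for illustration (not real CI data), and this assumes TestRun accepts a
    # bare test module name as defined in tools.testing.test_run.
    example_file_times = {"test_alpha": 300.0, "test_beta": 1500.0, "test_gamma": 45.0}
    example_tests = [TestRun(name) for name in example_file_times]

    # test_beta exceeds THRESHOLD, so it should be split into pytest shards;
    # the must_serial callable forces it onto the serial portion of a shard.
    example_shards = calculate_shards(
        num_shards=2,
        tests=example_tests,
        test_file_times=example_file_times,
        test_class_times=None,
        must_serial=lambda name: name == "test_beta",
    )
    for total_time, sharded_tests in example_shards:
        print(f"{total_time:.1f}s -> {[t.name for t in sharded_tests]}")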