1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport math 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport subprocess 6*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path 7*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Sequence 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerfrom tools.stats.import_test_stats import get_disabled_tests 10*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.test_run import ShardedTest, TestRun 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard WorkerREPO_ROOT = Path(__file__).resolve().parent.parent.parent 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard WorkerIS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1" 16*da0073e9SAndroid Build Coastguard WorkerBUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "") 17*da0073e9SAndroid Build Coastguard WorkerUSE_3_PROCS = "sm86" in BUILD_ENVIRONMENT or "cuda" not in BUILD_ENVIRONMENT 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job 20*da0073e9SAndroid Build Coastguard Worker# to ensure that sharding is consistent, NUM_PROCS is the actual number of procs 21*da0073e9SAndroid Build Coastguard Worker# used to run tests. If they are not equal, the only consequence should be 22*da0073e9SAndroid Build Coastguard Worker# unequal shards. 23*da0073e9SAndroid Build Coastguard WorkerIS_ROCM = os.path.exists("/opt/rocm") 24*da0073e9SAndroid Build Coastguard WorkerNUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if USE_3_PROCS else 2 25*da0073e9SAndroid Build Coastguard WorkerNUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2 26*da0073e9SAndroid Build Coastguard WorkerTHRESHOLD = 60 * 10 # 10 minutes 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker# See Note [ROCm parallel CI testing] 29*da0073e9SAndroid Build Coastguard Worker# Special logic for ROCm GHA runners to query number of GPUs available. 30*da0073e9SAndroid Build Coastguard Worker# torch.version.hip was not available to check if this was a ROCm self-hosted runner. 31*da0073e9SAndroid Build Coastguard Worker# Must check for ROCm runner in another way. We look for /opt/rocm directory. 32*da0073e9SAndroid Build Coastguard Workerif IS_ROCM and not IS_MEM_LEAK_CHECK: 33*da0073e9SAndroid Build Coastguard Worker try: 34*da0073e9SAndroid Build Coastguard Worker # This is the same logic used in GHA health check, see .github/templates/common.yml.j2 35*da0073e9SAndroid Build Coastguard Worker lines = ( 36*da0073e9SAndroid Build Coastguard Worker subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n") 37*da0073e9SAndroid Build Coastguard Worker ) 38*da0073e9SAndroid Build Coastguard Worker count = 0 39*da0073e9SAndroid Build Coastguard Worker for line in lines: 40*da0073e9SAndroid Build Coastguard Worker if " gfx" in line: 41*da0073e9SAndroid Build Coastguard Worker count += 1 42*da0073e9SAndroid Build Coastguard Worker assert count > 0 # there must be at least 1 GPU 43*da0073e9SAndroid Build Coastguard Worker # Limiting to 8 GPUs(PROCS) 44*da0073e9SAndroid Build Coastguard Worker NUM_PROCS = min(count, 8) 45*da0073e9SAndroid Build Coastguard Worker except subprocess.CalledProcessError as e: 46*da0073e9SAndroid Build Coastguard Worker # The safe default for ROCm GHA runners is to run tests serially. 47*da0073e9SAndroid Build Coastguard Worker NUM_PROCS = 1 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Workerclass ShardJob: 51*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 52*da0073e9SAndroid Build Coastguard Worker self.serial: list[ShardedTest] = [] 53*da0073e9SAndroid Build Coastguard Worker self.parallel: list[ShardedTest] = [] 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker def get_total_time(self) -> float: 56*da0073e9SAndroid Build Coastguard Worker """Default is the value for which to substitute if a test has no time""" 57*da0073e9SAndroid Build Coastguard Worker procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)] 58*da0073e9SAndroid Build Coastguard Worker for test in self.parallel: 59*da0073e9SAndroid Build Coastguard Worker min_index = procs.index(min(procs)) 60*da0073e9SAndroid Build Coastguard Worker procs[min_index] += test.get_time() 61*da0073e9SAndroid Build Coastguard Worker time = max(procs) + sum(test.get_time() for test in self.serial) 62*da0073e9SAndroid Build Coastguard Worker return time 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker def convert_to_tuple(self) -> tuple[float, list[ShardedTest]]: 65*da0073e9SAndroid Build Coastguard Worker return (self.get_total_time(), self.serial + self.parallel) 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Workerdef get_with_pytest_shard( 69*da0073e9SAndroid Build Coastguard Worker tests: Sequence[TestRun], 70*da0073e9SAndroid Build Coastguard Worker test_file_times: dict[str, float], 71*da0073e9SAndroid Build Coastguard Worker test_class_times: dict[str, dict[str, float]] | None, 72*da0073e9SAndroid Build Coastguard Worker) -> list[ShardedTest]: 73*da0073e9SAndroid Build Coastguard Worker sharded_tests: list[ShardedTest] = [] 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker for test in tests: 76*da0073e9SAndroid Build Coastguard Worker duration = get_duration(test, test_file_times, test_class_times or {}) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker if duration and duration > THRESHOLD: 79*da0073e9SAndroid Build Coastguard Worker num_shards = math.ceil(duration / THRESHOLD) 80*da0073e9SAndroid Build Coastguard Worker for i in range(num_shards): 81*da0073e9SAndroid Build Coastguard Worker sharded_tests.append( 82*da0073e9SAndroid Build Coastguard Worker ShardedTest(test, i + 1, num_shards, duration / num_shards) 83*da0073e9SAndroid Build Coastguard Worker ) 84*da0073e9SAndroid Build Coastguard Worker else: 85*da0073e9SAndroid Build Coastguard Worker sharded_tests.append(ShardedTest(test, 1, 1, duration)) 86*da0073e9SAndroid Build Coastguard Worker return sharded_tests 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Workerdef get_duration( 90*da0073e9SAndroid Build Coastguard Worker test: TestRun, 91*da0073e9SAndroid Build Coastguard Worker test_file_times: dict[str, float], 92*da0073e9SAndroid Build Coastguard Worker test_class_times: dict[str, dict[str, float]], 93*da0073e9SAndroid Build Coastguard Worker) -> float | None: 94*da0073e9SAndroid Build Coastguard Worker """Calculate the time for a TestRun based on the given test_file_times and 95*da0073e9SAndroid Build Coastguard Worker test_class_times. Returns None if the time is unknown.""" 96*da0073e9SAndroid Build Coastguard Worker file_duration = test_file_times.get(test.test_file, None) 97*da0073e9SAndroid Build Coastguard Worker if test.is_full_file(): 98*da0073e9SAndroid Build Coastguard Worker return file_duration 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker def get_duration_for_classes( 101*da0073e9SAndroid Build Coastguard Worker test_file: str, test_classes: frozenset[str] 102*da0073e9SAndroid Build Coastguard Worker ) -> float | None: 103*da0073e9SAndroid Build Coastguard Worker duration: float = 0 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker for test_class in test_classes: 106*da0073e9SAndroid Build Coastguard Worker class_duration = test_class_times.get(test_file, {}).get(test_class, None) 107*da0073e9SAndroid Build Coastguard Worker if class_duration is None: 108*da0073e9SAndroid Build Coastguard Worker return None 109*da0073e9SAndroid Build Coastguard Worker duration += class_duration 110*da0073e9SAndroid Build Coastguard Worker return duration 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker included = test.included() 113*da0073e9SAndroid Build Coastguard Worker excluded = test.excluded() 114*da0073e9SAndroid Build Coastguard Worker included_classes_duration = get_duration_for_classes(test.test_file, included) 115*da0073e9SAndroid Build Coastguard Worker excluded_classes_duration = get_duration_for_classes(test.test_file, excluded) 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker if included_classes_duration is None or excluded_classes_duration is None: 118*da0073e9SAndroid Build Coastguard Worker # Didn't get the time for all classes, so time is unknown 119*da0073e9SAndroid Build Coastguard Worker return None 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker if included: 122*da0073e9SAndroid Build Coastguard Worker return included_classes_duration 123*da0073e9SAndroid Build Coastguard Worker assert ( 124*da0073e9SAndroid Build Coastguard Worker excluded 125*da0073e9SAndroid Build Coastguard Worker ), f"TestRun {test} is not full file but doesn't have included or excluded classes" 126*da0073e9SAndroid Build Coastguard Worker if file_duration is None: 127*da0073e9SAndroid Build Coastguard Worker return None 128*da0073e9SAndroid Build Coastguard Worker return file_duration - excluded_classes_duration 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Workerdef shard( 132*da0073e9SAndroid Build Coastguard Worker sharded_jobs: list[ShardJob], 133*da0073e9SAndroid Build Coastguard Worker pytest_sharded_tests: Sequence[ShardedTest], 134*da0073e9SAndroid Build Coastguard Worker estimated_time_limit: float | None = None, 135*da0073e9SAndroid Build Coastguard Worker serial: bool = False, 136*da0073e9SAndroid Build Coastguard Worker) -> None: 137*da0073e9SAndroid Build Coastguard Worker # Modifies sharded_jobs in place 138*da0073e9SAndroid Build Coastguard Worker if len(sharded_jobs) == 0: 139*da0073e9SAndroid Build Coastguard Worker assert ( 140*da0073e9SAndroid Build Coastguard Worker len(pytest_sharded_tests) == 0 141*da0073e9SAndroid Build Coastguard Worker ), "No shards provided but there are tests to shard" 142*da0073e9SAndroid Build Coastguard Worker return 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker round_robin_index = 0 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker def _get_min_sharded_job( 147*da0073e9SAndroid Build Coastguard Worker sharded_jobs: list[ShardJob], test: ShardedTest 148*da0073e9SAndroid Build Coastguard Worker ) -> ShardJob: 149*da0073e9SAndroid Build Coastguard Worker if test.time is None: 150*da0073e9SAndroid Build Coastguard Worker nonlocal round_robin_index 151*da0073e9SAndroid Build Coastguard Worker job = sharded_jobs[round_robin_index % len(sharded_jobs)] 152*da0073e9SAndroid Build Coastguard Worker round_robin_index += 1 153*da0073e9SAndroid Build Coastguard Worker return job 154*da0073e9SAndroid Build Coastguard Worker return min(sharded_jobs, key=lambda j: j.get_total_time()) 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker def _shard_serial( 157*da0073e9SAndroid Build Coastguard Worker tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob] 158*da0073e9SAndroid Build Coastguard Worker ) -> None: 159*da0073e9SAndroid Build Coastguard Worker assert estimated_time_limit is not None, "Estimated time limit must be provided" 160*da0073e9SAndroid Build Coastguard Worker new_sharded_jobs = sharded_jobs 161*da0073e9SAndroid Build Coastguard Worker for test in tests: 162*da0073e9SAndroid Build Coastguard Worker if ( 163*da0073e9SAndroid Build Coastguard Worker len(sharded_jobs) > 1 164*da0073e9SAndroid Build Coastguard Worker and sharded_jobs[-1].get_total_time() > estimated_time_limit 165*da0073e9SAndroid Build Coastguard Worker ): 166*da0073e9SAndroid Build Coastguard Worker new_sharded_jobs = sharded_jobs[:-1] 167*da0073e9SAndroid Build Coastguard Worker min_sharded_job = _get_min_sharded_job(new_sharded_jobs, test) 168*da0073e9SAndroid Build Coastguard Worker min_sharded_job.serial.append(test) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker def _shard_parallel( 171*da0073e9SAndroid Build Coastguard Worker tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob] 172*da0073e9SAndroid Build Coastguard Worker ) -> None: 173*da0073e9SAndroid Build Coastguard Worker for test in tests: 174*da0073e9SAndroid Build Coastguard Worker min_sharded_job = _get_min_sharded_job(sharded_jobs, test) 175*da0073e9SAndroid Build Coastguard Worker min_sharded_job.parallel.append(test) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker if serial: 178*da0073e9SAndroid Build Coastguard Worker _shard_serial(pytest_sharded_tests, sharded_jobs) 179*da0073e9SAndroid Build Coastguard Worker else: 180*da0073e9SAndroid Build Coastguard Worker _shard_parallel(pytest_sharded_tests, sharded_jobs) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker return 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Workerdef calculate_shards( 186*da0073e9SAndroid Build Coastguard Worker num_shards: int, 187*da0073e9SAndroid Build Coastguard Worker tests: Sequence[TestRun], 188*da0073e9SAndroid Build Coastguard Worker test_file_times: dict[str, float], 189*da0073e9SAndroid Build Coastguard Worker test_class_times: dict[str, dict[str, float]] | None, 190*da0073e9SAndroid Build Coastguard Worker must_serial: Callable[[str], bool] | None = None, 191*da0073e9SAndroid Build Coastguard Worker sort_by_time: bool = True, 192*da0073e9SAndroid Build Coastguard Worker) -> list[tuple[float, list[ShardedTest]]]: 193*da0073e9SAndroid Build Coastguard Worker must_serial = must_serial or (lambda x: True) 194*da0073e9SAndroid Build Coastguard Worker test_class_times = test_class_times or {} 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker # Divide tests into pytest shards 197*da0073e9SAndroid Build Coastguard Worker if sort_by_time: 198*da0073e9SAndroid Build Coastguard Worker known_tests = [ 199*da0073e9SAndroid Build Coastguard Worker x 200*da0073e9SAndroid Build Coastguard Worker for x in tests 201*da0073e9SAndroid Build Coastguard Worker if get_duration(x, test_file_times, test_class_times) is not None 202*da0073e9SAndroid Build Coastguard Worker ] 203*da0073e9SAndroid Build Coastguard Worker unknown_tests = [x for x in tests if x not in known_tests] 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker pytest_sharded_tests = sorted( 206*da0073e9SAndroid Build Coastguard Worker get_with_pytest_shard(known_tests, test_file_times, test_class_times), 207*da0073e9SAndroid Build Coastguard Worker key=lambda j: j.get_time(), 208*da0073e9SAndroid Build Coastguard Worker reverse=True, 209*da0073e9SAndroid Build Coastguard Worker ) + get_with_pytest_shard(unknown_tests, test_file_times, test_class_times) 210*da0073e9SAndroid Build Coastguard Worker else: 211*da0073e9SAndroid Build Coastguard Worker pytest_sharded_tests = get_with_pytest_shard( 212*da0073e9SAndroid Build Coastguard Worker tests, test_file_times, test_class_times 213*da0073e9SAndroid Build Coastguard Worker ) 214*da0073e9SAndroid Build Coastguard Worker del tests 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker serial_tests = [test for test in pytest_sharded_tests if must_serial(test.name)] 217*da0073e9SAndroid Build Coastguard Worker parallel_tests = [test for test in pytest_sharded_tests if test not in serial_tests] 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker serial_time = sum(test.get_time() for test in serial_tests) 220*da0073e9SAndroid Build Coastguard Worker parallel_time = sum(test.get_time() for test in parallel_tests) 221*da0073e9SAndroid Build Coastguard Worker total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC 222*da0073e9SAndroid Build Coastguard Worker estimated_time_per_shard = total_time / num_shards 223*da0073e9SAndroid Build Coastguard Worker # Separate serial tests from parallel tests as much as possible to maximize 224*da0073e9SAndroid Build Coastguard Worker # parallelism by putting all the serial tests on the first num_serial_shards 225*da0073e9SAndroid Build Coastguard Worker # shards. The estimated_time_limit is the estimated time it should take for 226*da0073e9SAndroid Build Coastguard Worker # the least filled serial shard. Ex if we have 8 min of serial tests, 20 min 227*da0073e9SAndroid Build Coastguard Worker # of parallel tests, 6 shards, and 2 procs per machine, we would expect each 228*da0073e9SAndroid Build Coastguard Worker # machine to take 3 min and should aim for 3 serial shards, with shards 1 229*da0073e9SAndroid Build Coastguard Worker # and 2 taking 3 min and shard 3 taking 2 min. The estimated time limit 230*da0073e9SAndroid Build Coastguard Worker # would be 2 min. This ensures that the first few shard contains as many 231*da0073e9SAndroid Build Coastguard Worker # serial tests as possible and as few parallel tests as possible. The least 232*da0073e9SAndroid Build Coastguard Worker # filled/last (in the example, the 3rd) shard may contain a lot of both 233*da0073e9SAndroid Build Coastguard Worker # serial and parallel tests. 234*da0073e9SAndroid Build Coastguard Worker estimated_time_limit = 0.0 235*da0073e9SAndroid Build Coastguard Worker if estimated_time_per_shard != 0: 236*da0073e9SAndroid Build Coastguard Worker estimated_time_limit = serial_time % estimated_time_per_shard 237*da0073e9SAndroid Build Coastguard Worker if estimated_time_limit <= 0.01: 238*da0073e9SAndroid Build Coastguard Worker estimated_time_limit = estimated_time_per_shard 239*da0073e9SAndroid Build Coastguard Worker if total_time == 0: 240*da0073e9SAndroid Build Coastguard Worker num_serial_shards = num_shards 241*da0073e9SAndroid Build Coastguard Worker else: 242*da0073e9SAndroid Build Coastguard Worker num_serial_shards = max(math.ceil(serial_time / total_time * num_shards), 1) 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker sharded_jobs = [ShardJob() for _ in range(num_shards)] 245*da0073e9SAndroid Build Coastguard Worker shard( 246*da0073e9SAndroid Build Coastguard Worker sharded_jobs=sharded_jobs[:num_serial_shards], 247*da0073e9SAndroid Build Coastguard Worker pytest_sharded_tests=serial_tests, 248*da0073e9SAndroid Build Coastguard Worker estimated_time_limit=estimated_time_limit, 249*da0073e9SAndroid Build Coastguard Worker serial=True, 250*da0073e9SAndroid Build Coastguard Worker ) 251*da0073e9SAndroid Build Coastguard Worker shard( 252*da0073e9SAndroid Build Coastguard Worker sharded_jobs=sharded_jobs, 253*da0073e9SAndroid Build Coastguard Worker pytest_sharded_tests=parallel_tests, 254*da0073e9SAndroid Build Coastguard Worker serial=False, 255*da0073e9SAndroid Build Coastguard Worker ) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker return [job.convert_to_tuple() for job in sharded_jobs] 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Workerdef get_test_case_configs(dirpath: str) -> None: 261*da0073e9SAndroid Build Coastguard Worker get_disabled_tests(dirpath=dirpath) 262