xref: /aosp_15_r20/external/pytorch/tools/testing/test_selections.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport math
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport subprocess
6*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
7*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Sequence
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom tools.stats.import_test_stats import get_disabled_tests
10*da0073e9SAndroid Build Coastguard Workerfrom tools.testing.test_run import ShardedTest, TestRun
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
REPO_ROOT = Path(__file__).resolve().parent.parent.parent

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
# 3 procs on sm86 builds and on any build whose environment does not mention
# "cuda"; otherwise 2 (see NUM_PROCS below).
USE_3_PROCS = "sm86" in BUILD_ENVIRONMENT or "cuda" not in BUILD_ENVIRONMENT

IS_ROCM = os.path.exists("/opt/rocm")
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if USE_3_PROCS else 2
# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
# to ensure that sharding is consistent; NUM_PROCS is the actual number of procs
# used to run tests.  If they are not equal, the only consequence should be
# unequal shards.  Note that this is computed before the ROCm GPU probe below,
# so the per-machine value that probe assigns to NUM_PROCS does not affect it.
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10  # 10 minutes

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query number of GPUs available.
# torch.version.hip was not available to check if this was a ROCm self-hosted runner.
# Must check for ROCm runner in another way. We look for /opt/rocm directory.
if IS_ROCM and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in GHA health check, see .github/templates/common.yml.j2
        # errors="replace" keeps a stray non-ASCII byte in the output from
        # raising; the " gfx" marker we count is plain ASCII.
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii", errors="replace")
            .strip()
            .split("\n")
        )
        count = sum(1 for line in lines if " gfx" in line)
        assert count > 0  # there must be at least 1 GPU
        # Limiting to 8 GPUs(PROCS)
        NUM_PROCS = min(count, 8)
    except (subprocess.CalledProcessError, OSError):
        # The safe default for ROCm GHA runners is to run tests serially.
        # OSError covers rocminfo being absent from PATH or not executable,
        # which would otherwise crash the import of this module.
        NUM_PROCS = 1
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker
class ShardJob:
    """The pytest shards assigned to one CI shard.

    ``serial`` and ``parallel`` hold the two groups of tests; they are combined
    differently when estimating the shard's running time (see get_total_time).
    """

    def __init__(self) -> None:
        self.serial: list[ShardedTest] = []
        self.parallel: list[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate how long this shard takes to run.

        Parallel tests are assigned greedily, in list order, to whichever of
        NUM_PROCS_FOR_SHARDING_CALC simulated processes currently has the least
        accumulated time. The busiest process bounds the parallel phase, and
        the serial tests' times are then added on top. Each test's time comes
        from ``ShardedTest.get_time()``.
        """
        procs = [0.0] * NUM_PROCS_FOR_SHARDING_CALC
        for test in self.parallel:
            # First least-loaded process; ties go to the lowest index.
            min_index = procs.index(min(procs))
            procs[min_index] += test.get_time()
        return max(procs) + sum(test.get_time() for test in self.serial)

    def convert_to_tuple(self) -> tuple[float, list[ShardedTest]]:
        """Return ``(estimated_total_time, serial_tests + parallel_tests)``."""
        return (self.get_total_time(), self.serial + self.parallel)
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker
def get_with_pytest_shard(
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
) -> list[ShardedTest]:
    """Break every test run into pytest shards.

    A run whose known duration is greater than THRESHOLD seconds is split into
    ``ceil(duration / THRESHOLD)`` pieces, and each piece is credited with an
    equal share of the duration. Any other run, including one whose duration
    is unknown (None) or zero, becomes a single 1-of-1 shard carrying that
    duration unchanged.
    """
    class_times = test_class_times or {}
    pieces_out: list[ShardedTest] = []

    for run in tests:
        estimate = get_duration(run, test_file_times, class_times)
        if estimate and estimate > THRESHOLD:
            piece_count = math.ceil(estimate / THRESHOLD)
            piece_time = estimate / piece_count
            pieces_out.extend(
                ShardedTest(run, shard_no, piece_count, piece_time)
                for shard_no in range(1, piece_count + 1)
            )
        else:
            pieces_out.append(ShardedTest(run, 1, 1, estimate))

    return pieces_out
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker
def get_duration(
    test: TestRun,
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]],
) -> float | None:
    """Estimate the time for a TestRun, or return None if it is unknown.

    * A full-file run uses the recorded time of ``test.test_file``.
    * A run with included classes uses the sum of those classes' times.
    * Otherwise (excluded classes only), the file time minus the excluded
      classes' times is used.

    The result is None whenever a needed class time is missing. The times of
    both the included and the excluded classes are always checked. It is also
    None when, in the exclusion case, the file time is missing.

    Raises:
        AssertionError: if the run is not a full file but lists neither
            included nor excluded classes.
    """
    whole_file = test_file_times.get(test.test_file, None)
    if test.is_full_file():
        return whole_file

    times_in_file = test_class_times.get(test.test_file, {})

    def _total(classes: frozenset[str]) -> float | None:
        # Sum the recorded class times; bail out with None on the first gap.
        running: float = 0
        for cls_name in classes:
            t = times_in_file.get(cls_name, None)
            if t is None:
                return None
            running += t
        return running

    included = test.included()
    excluded = test.excluded()
    included_total = _total(included)
    excluded_total = _total(excluded)
    if included_total is None or excluded_total is None:
        return None  # at least one class has no recorded time

    if included:
        return included_total

    assert (
        excluded
    ), f"TestRun {test} is not full file but doesn't have included or excluded classes"
    return None if whole_file is None else whole_file - excluded_total
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker
def shard(
    sharded_jobs: list[ShardJob],
    pytest_sharded_tests: Sequence[ShardedTest],
    estimated_time_limit: float | None = None,
    serial: bool = False,
) -> None:
    """Distribute ``pytest_sharded_tests`` over ``sharded_jobs`` in place.

    A test with a known ``time`` goes to the job with the smallest
    ``get_total_time()``; tests whose ``time`` is None are dealt round-robin.
    With ``serial=True`` tests are appended to each job's ``serial`` list and
    ``estimated_time_limit`` must be given. Once there is more than one job
    and the last one's total time exceeds that limit, the last job receives
    no further tests. Otherwise tests are appended to ``parallel``.

    Raises:
        AssertionError: if there are tests but no jobs, or if ``serial`` is
            set without ``estimated_time_limit``.
    """
    if not sharded_jobs:
        assert (
            not pytest_sharded_tests
        ), "No shards provided but there are tests to shard"
        return

    next_slot = 0

    def _pick_job(candidates: list[ShardJob], test: ShardedTest) -> ShardJob:
        # Known time: least-loaded job. Unknown time: next job in rotation.
        nonlocal next_slot
        if test.time is not None:
            return min(candidates, key=lambda job: job.get_total_time())
        chosen = candidates[next_slot % len(candidates)]
        next_slot += 1
        return chosen

    if not serial:
        for test in pytest_sharded_tests:
            _pick_job(sharded_jobs, test).parallel.append(test)
        return

    assert estimated_time_limit is not None, "Estimated time limit must be provided"
    candidates = sharded_jobs
    for test in pytest_sharded_tests:
        last_is_full = (
            len(sharded_jobs) > 1
            and sharded_jobs[-1].get_total_time() > estimated_time_limit
        )
        if last_is_full:
            # Stop feeding the last (least-filled) serial shard from now on.
            candidates = sharded_jobs[:-1]
        _pick_job(candidates, test).serial.append(test)
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker
def calculate_shards(
    num_shards: int,
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
    must_serial: Callable[[str], bool] | None = None,
    sort_by_time: bool = True,
) -> list[tuple[float, list[ShardedTest]]]:
    """Distribute ``tests`` across ``num_shards`` CI shards.

    Each test is first split into pytest shards (see get_with_pytest_shard).
    Pytest shards whose name satisfies ``must_serial`` are packed onto the first
    ``num_serial_shards`` shards. The remaining (parallel) pytest shards are
    then spread over all shards.

    Args:
        num_shards: number of CI shards to produce; expected to be positive
            (it is used as a divisor).
        tests: test runs to distribute.
        test_file_times: known per-file durations in seconds.
        test_class_times: known per-class durations, keyed by file then class.
        must_serial: predicate on a test name; when omitted every test is
            treated as serial.
        sort_by_time: when true, tests with a known duration come first,
            longest first, followed by tests with an unknown duration in input
            order.

    Returns:
        One ``(estimated_time, pytest_shards)`` tuple per shard, in shard order.
    """
    must_serial = must_serial or (lambda x: True)
    test_class_times = test_class_times or {}

    # Divide tests into pytest shards
    if sort_by_time:
        # Partition in a single pass instead of filtering with list membership
        # (`x not in known_tests`), which is quadratic in the number of tests.
        known_tests: list[TestRun] = []
        unknown_tests: list[TestRun] = []
        for test in tests:
            if get_duration(test, test_file_times, test_class_times) is None:
                unknown_tests.append(test)
            else:
                known_tests.append(test)

        pytest_sharded_tests = sorted(
            get_with_pytest_shard(known_tests, test_file_times, test_class_times),
            key=lambda j: j.get_time(),
            reverse=True,
        ) + get_with_pytest_shard(unknown_tests, test_file_times, test_class_times)
    else:
        pytest_sharded_tests = get_with_pytest_shard(
            tests, test_file_times, test_class_times
        )
    del tests

    # Single-pass split; must_serial is called exactly once per pytest shard.
    serial_tests: list[ShardedTest] = []
    parallel_tests: list[ShardedTest] = []
    for test in pytest_sharded_tests:
        if must_serial(test.name):
            serial_tests.append(test)
        else:
            parallel_tests.append(test)

    serial_time = sum(test.get_time() for test in serial_tests)
    parallel_time = sum(test.get_time() for test in parallel_tests)
    total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC
    estimated_time_per_shard = total_time / num_shards
    # Separate serial tests from parallel tests as much as possible to maximize
    # parallelism by putting all the serial tests on the first num_serial_shards
    # shards. The estimated_time_limit is the estimated time it should take for
    # the least filled serial shard. For example, with 8 min of serial tests,
    # 20 min of parallel tests, 6 shards, and 2 procs per machine, the total is
    # 8 + 20 / 2 = 18 min, so each machine should take 3 min and we should aim
    # for ceil(8 / 18 * 6) = 3 serial shards, with shards 1 and 2 taking 3 min
    # and shard 3 taking 2 min.  The estimated time limit would be 8 % 3 = 2 min.
    # This ensures that the first few shards contain as many serial tests as
    # possible and as few parallel tests as possible. The least filled/last
    # (in the example, the 3rd) serial shard may contain a lot of both serial
    # and parallel tests.
    estimated_time_limit = 0.0
    if estimated_time_per_shard != 0:
        estimated_time_limit = serial_time % estimated_time_per_shard
    if estimated_time_limit <= 0.01:
        estimated_time_limit = estimated_time_per_shard
    if total_time == 0:
        num_serial_shards = num_shards
    else:
        num_serial_shards = max(math.ceil(serial_time / total_time * num_shards), 1)

    sharded_jobs = [ShardJob() for _ in range(num_shards)]
    shard(
        sharded_jobs=sharded_jobs[:num_serial_shards],
        pytest_sharded_tests=serial_tests,
        estimated_time_limit=estimated_time_limit,
        serial=True,
    )
    shard(
        sharded_jobs=sharded_jobs,
        pytest_sharded_tests=parallel_tests,
        serial=False,
    )

    return [job.convert_to_tuple() for job in sharded_jobs]
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker
def get_test_case_configs(dirpath: str) -> None:
    """Call ``get_disabled_tests`` for ``dirpath`` and discard its result.

    NOTE(review): presumably this fetches/writes disabled-test data under
    ``dirpath`` as a side effect; confirm against tools.stats.import_test_stats.
    """
    get_disabled_tests(dirpath=dirpath)
    return None
262