xref: /aosp_15_r20/external/pytorch/tools/test/test_test_selections.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import functools
4import random
5import sys
6import unittest
7from collections import defaultdict
8from pathlib import Path
9
10
# Directory three levels above this file; added to sys.path so the `tools`
# package imported below can be resolved.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(REPO_ROOT))
try:
    # using tools/ to optimize test run.
    from tools.testing.test_run import ShardedTest, TestRun
    from tools.testing.test_selections import calculate_shards, THRESHOLD
except ModuleNotFoundError as err:
    # Report on stderr and include the underlying error so the missing module
    # is identifiable; the exit status stays 1.
    print(f"Can't import required modules, exiting: {err}", file=sys.stderr)
    sys.exit(1)
20
21
def gen_class_times(test_times: dict[str, float]) -> dict[str, dict[str, float]]:
    """Build per-class timings by filing each test's time under ``"class1"``.

    Every key of ``test_times`` maps to a fresh one-entry dict whose only key
    is ``"class1"`` and whose value is that test's time.
    """
    class_times: dict[str, dict[str, float]] = {}
    for test_name, duration in test_times.items():
        class_times[test_name] = {"class1": duration}
    return class_times
24
25
class TestCalculateShards(unittest.TestCase):
    """Tests for ``calculate_shards`` using a fixed set of named test files.

    ``tests`` lists the test files, ``test_times`` gives a per-file duration
    and ``test_class_times`` gives per-class durations for the same files.
    """

    tests: list[TestRun] = [
        TestRun("super_long_test"),
        TestRun("long_test1"),
        TestRun("long_test2"),
        TestRun("normal_test1"),
        TestRun("normal_test2"),
        TestRun("normal_test3"),
        TestRun("short_test1"),
        TestRun("short_test2"),
        TestRun("short_test3"),
        TestRun("short_test4"),
        TestRun("short_test5"),
    ]

    test_times: dict[str, float] = {
        "super_long_test": 55,
        "long_test1": 22,
        "long_test2": 18,
        "normal_test1": 9,
        "normal_test2": 7,
        "normal_test3": 5,
        "short_test1": 1,
        "short_test2": 0.6,
        "short_test3": 0.4,
        "short_test4": 0.3,
        "short_test5": 0.01,
    }

    test_class_times: dict[str, dict[str, float]] = {
        "super_long_test": {"class1": 55},
        "long_test1": {"class1": 1, "class2": 21},
        "long_test2": {"class1": 10, "class2": 8},
        "normal_test1": {"class1": 9},
        "normal_test2": {"class1": 7},
        "normal_test3": {"class1": 5},
        "short_test1": {"class1": 1},
        "short_test2": {"class1": 0.6},
        "short_test3": {"class1": 0.4},
        "short_test4": {"class1": 0.3},
        "short_test5": {"class1": 0.01},
    }

    def assert_shards_equal(
        self,
        expected_shards: list[tuple[float, list[ShardedTest]]],
        actual_shards: list[tuple[float, list[ShardedTest]]],
    ) -> None:
        """Assert that two shard lists match shard by shard.

        Each shard is a ``(total_time, sharded_tests)`` tuple. Totals are
        compared with ``assertAlmostEqual`` and the test lists with
        ``assertListEqual``.
        """
        # zip() stops at the shorter input, so without this check a missing or
        # extra shard would go unnoticed.
        self.assertEqual(
            len(expected_shards),
            len(actual_shards),
            "number of shards differs from expectation",
        )
        for expected, actual in zip(expected_shards, actual_shards):
            self.assertAlmostEqual(expected[0], actual[0])
            self.assertListEqual(expected[1], actual[1])

    def test_no_times(self) -> None:
        """With no timing data, tests alternate between the two shards."""
        # Check that round robin sharding is used when no times are provided
        expected_shards = [
            (
                0.0,
                [
                    ShardedTest(
                        test="super_long_test", shard=1, num_shards=1, time=None
                    ),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                0.0,
                [
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=None),
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=None),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(2, self.tests, {}, {}, sort_by_time=False),
        )

    def test_some_times_with_not_sort_by_time(self) -> None:
        """Only two of five tests have times and ``sort_by_time`` is off.

        Within each expected shard the tests appear in their input order.
        """
        expected_shards = [
            (
                400.0,
                [
                    ShardedTest(test="test_1", shard=1, num_shards=1, time=None),
                    ShardedTest(test="test_2", shard=1, num_shards=1, time=400),
                    ShardedTest(test="test_5", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                300.0,
                [
                    ShardedTest(test="test_3", shard=1, num_shards=1, time=300),
                    ShardedTest(test="test_4", shard=1, num_shards=1, time=None),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [
                    TestRun("test_1"),
                    TestRun("test_2"),
                    TestRun("test_3"),
                    TestRun("test_4"),
                    TestRun("test_5"),
                ],
                {"test_2": 400, "test_3": 300},
                {},
                sort_by_time=False,
            ),
        )

    def test_serial_parallel_interleaving(self) -> None:
        """Same inputs as above, but ``must_serial`` flags test_1 and test_3.

        Both flagged tests are expected on the first shard, ahead of test_4.
        """
        expected_shards = [
            (
                300.0,
                [
                    ShardedTest(test="test_1", shard=1, num_shards=1, time=None),
                    ShardedTest(test="test_3", shard=1, num_shards=1, time=300),
                    ShardedTest(test="test_4", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                400.0,
                [
                    ShardedTest(test="test_2", shard=1, num_shards=1, time=400),
                    ShardedTest(test="test_5", shard=1, num_shards=1, time=None),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [
                    TestRun("test_1"),
                    TestRun("test_2"),
                    TestRun("test_3"),
                    TestRun("test_4"),
                    TestRun("test_5"),
                ],
                {"test_2": 400, "test_3": 300},
                {},
                must_serial=lambda x: x in ["test_1", "test_3"],
                sort_by_time=False,
            ),
        )

    def test_calculate_2_shards_with_complete_test_times(self) -> None:
        """Two shards with a time for every test: expected totals 60 and 58.31."""
        expected_shards = [
            (
                60.0,
                [
                    ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
                ],
            ),
            (
                58.31,
                [
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(2, self.tests, self.test_times, self.test_class_times),
        )

    def test_calculate_1_shard_with_complete_test_times(self) -> None:
        """One shard holding every test, ordered by descending time.

        Two extra runs of ``long_test1`` are added: one excluding ``class2``
        and one including only ``class2``. Their expected times (1 and 21)
        match the per-class entries in ``test_class_times``.
        """
        tests = self.tests.copy()
        class_test1 = TestRun("long_test1", excluded=["class2"])
        class_test2 = TestRun("long_test1", included=["class2"])
        tests.append(class_test1)
        tests.append(class_test2)

        expected_shards = [
            (
                140.31,
                [
                    ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55),
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(class_test2, shard=1, num_shards=1, time=21),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=18),
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(class_test1, shard=1, num_shards=1, time=1),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
                ],
            )
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(1, tests, self.test_times, self.test_class_times),
        )

    def test_calculate_5_shards_with_complete_test_times(self) -> None:
        """Five shards with a time for every test.

        The three longest tests each get a shard of their own; the remaining
        tests are spread over the last two shards.
        """
        expected_shards = [
            (
                55.0,
                [ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55)],
            ),
            (22.0, [ShardedTest(test="long_test1", shard=1, num_shards=1, time=22)]),
            (18.0, [ShardedTest(test="long_test2", shard=1, num_shards=1, time=18)]),
            (
                11.31,
                [
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01),
                ],
            ),
            (
                12.0,
                [
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(5, self.tests, self.test_times, self.test_class_times),
        )

    def test_calculate_2_shards_with_incomplete_test_times(self) -> None:
        """Two shards when only the ``*test1`` files have known times.

        Timed tests are expected first in each shard, followed by the untimed
        ones (``time=None``), which add nothing to the shard total.
        """
        # Keeps long_test1, normal_test1 and short_test1 only.
        incomplete_test_times = {
            k: v for k, v in self.test_times.items() if "test1" in k
        }
        expected_shards = [
            (
                22.0,
                [
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(
                        test="super_long_test", shard=1, num_shards=1, time=None
                    ),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                10.0,
                [
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                self.tests,
                incomplete_test_times,
                gen_class_times(incomplete_test_times),
            ),
        )

    def test_calculate_5_shards_with_incomplete_test_times(self) -> None:
        """Five shards when only the ``*test1`` files have known times.

        The three timed tests lead the first three shards; the last two
        shards hold one untimed test each and have a total of 0.
        """
        # Keeps long_test1, normal_test1 and short_test1 only.
        incomplete_test_times = {
            k: v for k, v in self.test_times.items() if "test1" in k
        }
        expected_shards = [
            (
                22.0,
                [
                    ShardedTest(test="long_test1", shard=1, num_shards=1, time=22),
                    ShardedTest(
                        test="super_long_test", shard=1, num_shards=1, time=None
                    ),
                    ShardedTest(test="short_test3", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                9.0,
                [
                    ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9),
                    ShardedTest(test="long_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test4", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                1.0,
                [
                    ShardedTest(test="short_test1", shard=1, num_shards=1, time=1),
                    ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None),
                    ShardedTest(test="short_test5", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                0.0,
                [
                    ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None),
                ],
            ),
            (
                0.0,
                [
                    ShardedTest(test="short_test2", shard=1, num_shards=1, time=None),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                5,
                self.tests,
                incomplete_test_times,
                gen_class_times(incomplete_test_times),
            ),
        )

    def test_split_shards(self) -> None:
        """Check when a single test file is split into several pieces.

        Three scenarios, all relative to ``THRESHOLD``:
        1. both tests take exactly ``THRESHOLD``: neither is split;
        2. tests take 4x and 2.5x ``THRESHOLD``: split into 4 and 3 pieces;
        3. tests take half of and exactly ``THRESHOLD``: neither is split and
           the longer test is expected on the first shard.

        NOTE: the literal totals (600.0, 2200.0, ...) assume THRESHOLD == 600.
        """
        test_times: dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD}
        expected_shards = [
            (600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]),
            (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [TestRun(t) for t in test_times.keys()],
                test_times,
                gen_class_times(test_times),
            ),
        )

        test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5}
        expected_shards = [
            (
                2200.0,
                [
                    ShardedTest(test="test1", shard=1, num_shards=4, time=600.0),
                    ShardedTest(test="test1", shard=3, num_shards=4, time=600.0),
                    ShardedTest(test="test2", shard=1, num_shards=3, time=500.0),
                    ShardedTest(test="test2", shard=3, num_shards=3, time=500.0),
                ],
            ),
            (
                1700.0,
                [
                    ShardedTest(test="test1", shard=2, num_shards=4, time=600.0),
                    ShardedTest(test="test1", shard=4, num_shards=4, time=600.0),
                    ShardedTest(test="test2", shard=2, num_shards=3, time=500.0),
                ],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [TestRun(t) for t in test_times.keys()],
                test_times,
                gen_class_times(test_times),
            ),
        )

        test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD}
        expected_shards = [
            (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]),
            (
                300.0,
                [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD / 2)],
            ),
        ]
        self.assert_shards_equal(
            expected_shards,
            calculate_shards(
                2,
                [TestRun(t) for t in test_times.keys()],
                test_times,
                gen_class_times(test_times),
            ),
        )

    def test_zero_tests(self) -> None:
        """With no tests, every requested shard is empty with a total of 0."""
        self.assertListEqual([(0.0, []), (0.0, [])], calculate_shards(2, [], {}, None))

    def test_split_shards_random(self) -> None:
        """Check sharding invariants over 100 seeded random configurations.

        Each iteration draws a shard count, a set of tests, which of them are
        serial, which have known times, and whether to sort by time. It then
        checks that:
        * shard totals differ by at most ``THRESHOLD`` plus 60 per untimed test;
        * every test file appears in some shard;
        * untimed tests appear exactly once with ``time=None``, and the pieces
          of a timed test add up to its original time;
        * for each test file, shard indices 1..num_shards are all present;
        * within each shard, serial tests come first and the rest follow the
          ordering implied by ``sort_by_time``.
        """
        random.seed(120)
        for _ in range(100):
            num_shards = random.randint(1, 10)
            num_tests = random.randint(1, 100)
            test_names = [str(i) for i in range(num_tests)]
            tests = [TestRun(x) for x in test_names]
            serial = [x for x in test_names if random.randint(0, 1) == 0]
            has_times = [x for x in test_names if random.randint(0, 1) == 0]
            random_times: dict[str, float] = {
                i: random.randint(0, THRESHOLD * 10) for i in has_times
            }
            sort_by_time = random.randint(0, 1) == 0

            shards = calculate_shards(
                num_shards,
                tests,
                random_times,
                None,
                must_serial=lambda x: x in serial,
                sort_by_time=sort_by_time,
            )

            times = [x[0] for x in shards]
            max_diff = max(times) - min(times)
            self.assertLessEqual(
                max_diff, THRESHOLD + (num_tests - len(has_times)) * 60
            )

            all_sharded_tests: dict[str, list[ShardedTest]] = defaultdict(list)
            for _, sharded_tests in shards:
                for sharded_test in sharded_tests:
                    all_sharded_tests[sharded_test.name].append(sharded_test)

            # Check that all test files are represented in the shards
            self.assertListEqual(sorted(test_names), sorted(all_sharded_tests.keys()))
            # Check that for each test file, the pytest shards' times adds up to
            # original and all shards are present
            for test, sharded_tests in all_sharded_tests.items():
                if random_times.get(test) is None:
                    self.assertEqual(len(sharded_tests), 1)
                    self.assertIsNone(sharded_tests[0].time)
                else:
                    # x.time is not None because of the above check
                    self.assertAlmostEqual(
                        random_times[test], sum(x.time for x in sharded_tests)  # type: ignore[misc]
                    )
                self.assertListEqual(
                    list(range(sharded_tests[0].num_shards)),
                    sorted(x.shard - 1 for x in sharded_tests),
                )
            # Check that sort_by_time is respected
            if sort_by_time:

                def comparator(a: ShardedTest, b: ShardedTest) -> int:
                    # serial comes first
                    if a.name in serial and b.name not in serial:
                        return -1
                    if a.name not in serial and b.name in serial:
                        return 1
                    # known test times come first
                    if a.time is not None and b.time is None:
                        return -1
                    if a.time is None and b.time is not None:
                        return 1
                    if a.time == b.time:
                        return 0
                    # not None due to the above checks
                    return -1 if a.time > b.time else 1  # type: ignore[operator]

            else:

                def comparator(a: ShardedTest, b: ShardedTest) -> int:
                    # serial comes first
                    if a.name in serial and b.name not in serial:
                        return -1
                    if a.name not in serial and b.name in serial:
                        return 1
                    return test_names.index(a.name) - test_names.index(b.name)

            # sorted() is stable, so each shard must already be in comparator order.
            for _, sharded_tests in shards:
                self.assertListEqual(
                    sorted(sharded_tests, key=functools.cmp_to_key(comparator)),
                    sharded_tests,
                )

    def test_calculate_2_shards_against_optimal_shards(self) -> None:
        """Compare two-shard results with a known optimal split.

        Random times are drawn for every test, then ``super_long_test`` and
        ``long_test1`` are overwritten so that together they take exactly as
        long as all other tests combined (``sum_of_rest``). The longest
        calculated shard must stay within 7/6 of that optimum, and every test
        must appear in one of the two shards.
        """
        random.seed(120)
        for _ in range(100):
            random_times = {k.test_file: random.random() * 10 for k in self.tests}
            # all test times except first two
            rest_of_tests = [
                i
                for k, i in random_times.items()
                if k != "super_long_test" and k != "long_test1"
            ]
            sum_of_rest = sum(rest_of_tests)
            random_times["super_long_test"] = max(sum_of_rest / 2, *rest_of_tests)
            random_times["long_test1"] = sum_of_rest - random_times["super_long_test"]
            # An optimal sharding would look like the below, but we don't need to compute this for the test:
            # optimal_shards = [
            #     (sum_of_rest, ['super_long_test', 'long_test1']),
            #     (sum_of_rest, [i for i in self.tests if i != 'super_long_test' and i != 'long_test1']),
            # ]
            calculated_shards = calculate_shards(
                2, self.tests, random_times, gen_class_times(random_times)
            )
            max_shard_time = max(calculated_shards[0][0], calculated_shards[1][0])
            if sum_of_rest != 0:
                # The calculated shard should not have a ratio worse than 7/6 for num_shards = 2
                self.assertLessEqual(max_shard_time / sum_of_rest, 7.0 / 6.0)
                sorted_tests = sorted([t.test_file for t in self.tests])
                sorted_shard_tests = sorted(
                    calculated_shards[0][1] + calculated_shards[1][1]
                )
                # All the tests should be represented by some shard
                self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests])
550
551if __name__ == "__main__":
552    unittest.main()
553