1from __future__ import annotations 2 3import functools 4import random 5import sys 6import unittest 7from collections import defaultdict 8from pathlib import Path 9 10 11REPO_ROOT = Path(__file__).resolve().parent.parent.parent 12try: 13 # using tools/ to optimize test run. 14 sys.path.append(str(REPO_ROOT)) 15 from tools.testing.test_run import ShardedTest, TestRun 16 from tools.testing.test_selections import calculate_shards, THRESHOLD 17except ModuleNotFoundError: 18 print("Can't import required modules, exiting") 19 sys.exit(1) 20 21 22def gen_class_times(test_times: dict[str, float]) -> dict[str, dict[str, float]]: 23 return {k: {"class1": v} for k, v in test_times.items()} 24 25 26class TestCalculateShards(unittest.TestCase): 27 tests: list[TestRun] = [ 28 TestRun("super_long_test"), 29 TestRun("long_test1"), 30 TestRun("long_test2"), 31 TestRun("normal_test1"), 32 TestRun("normal_test2"), 33 TestRun("normal_test3"), 34 TestRun("short_test1"), 35 TestRun("short_test2"), 36 TestRun("short_test3"), 37 TestRun("short_test4"), 38 TestRun("short_test5"), 39 ] 40 41 test_times: dict[str, float] = { 42 "super_long_test": 55, 43 "long_test1": 22, 44 "long_test2": 18, 45 "normal_test1": 9, 46 "normal_test2": 7, 47 "normal_test3": 5, 48 "short_test1": 1, 49 "short_test2": 0.6, 50 "short_test3": 0.4, 51 "short_test4": 0.3, 52 "short_test5": 0.01, 53 } 54 55 test_class_times: dict[str, dict[str, float]] = { 56 "super_long_test": {"class1": 55}, 57 "long_test1": {"class1": 1, "class2": 21}, 58 "long_test2": {"class1": 10, "class2": 8}, 59 "normal_test1": {"class1": 9}, 60 "normal_test2": {"class1": 7}, 61 "normal_test3": {"class1": 5}, 62 "short_test1": {"class1": 1}, 63 "short_test2": {"class1": 0.6}, 64 "short_test3": {"class1": 0.4}, 65 "short_test4": {"class1": 0.3}, 66 "short_test5": {"class1": 0.01}, 67 } 68 69 def assert_shards_equal( 70 self, 71 expected_shards: list[tuple[float, list[ShardedTest]]], 72 actual_shards: list[tuple[float, list[ShardedTest]]], 73 ) -> None: 74 for expected, actual in zip(expected_shards, actual_shards): 75 self.assertAlmostEqual(expected[0], actual[0]) 76 self.assertListEqual(expected[1], actual[1]) 77 78 def test_no_times(self) -> None: 79 # Check that round robin sharding is used when no times are provided 80 expected_shards = [ 81 ( 82 0.0, 83 [ 84 ShardedTest( 85 test="super_long_test", shard=1, num_shards=1, time=None 86 ), 87 ShardedTest(test="long_test2", shard=1, num_shards=1, time=None), 88 ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None), 89 ShardedTest(test="short_test1", shard=1, num_shards=1, time=None), 90 ShardedTest(test="short_test3", shard=1, num_shards=1, time=None), 91 ShardedTest(test="short_test5", shard=1, num_shards=1, time=None), 92 ], 93 ), 94 ( 95 0.0, 96 [ 97 ShardedTest(test="long_test1", shard=1, num_shards=1, time=None), 98 ShardedTest(test="normal_test1", shard=1, num_shards=1, time=None), 99 ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None), 100 ShardedTest(test="short_test2", shard=1, num_shards=1, time=None), 101 ShardedTest(test="short_test4", shard=1, num_shards=1, time=None), 102 ], 103 ), 104 ] 105 self.assert_shards_equal( 106 expected_shards, 107 calculate_shards(2, self.tests, {}, {}, sort_by_time=False), 108 ) 109 110 def test_some_times_with_not_sort_by_time(self) -> None: 111 expected_shards = [ 112 ( 113 400.0, 114 [ 115 ShardedTest(test="test_1", shard=1, num_shards=1, time=None), 116 ShardedTest(test="test_2", shard=1, num_shards=1, time=400), 117 ShardedTest(test="test_5", shard=1, num_shards=1, time=None), 118 ], 119 ), 120 ( 121 300.0, 122 [ 123 ShardedTest(test="test_3", shard=1, num_shards=1, time=300), 124 ShardedTest(test="test_4", shard=1, num_shards=1, time=None), 125 ], 126 ), 127 ] 128 self.assert_shards_equal( 129 expected_shards, 130 calculate_shards( 131 2, 132 [ 133 TestRun("test_1"), 134 TestRun("test_2"), 135 TestRun("test_3"), 136 TestRun("test_4"), 137 TestRun("test_5"), 138 ], 139 {"test_2": 400, "test_3": 300}, 140 {}, 141 sort_by_time=False, 142 ), 143 ) 144 145 def test_serial_parallel_interleaving(self) -> None: 146 expected_shards = [ 147 ( 148 300.0, 149 [ 150 ShardedTest(test="test_1", shard=1, num_shards=1, time=None), 151 ShardedTest(test="test_3", shard=1, num_shards=1, time=300), 152 ShardedTest(test="test_4", shard=1, num_shards=1, time=None), 153 ], 154 ), 155 ( 156 400.0, 157 [ 158 ShardedTest(test="test_2", shard=1, num_shards=1, time=400), 159 ShardedTest(test="test_5", shard=1, num_shards=1, time=None), 160 ], 161 ), 162 ] 163 self.assert_shards_equal( 164 expected_shards, 165 calculate_shards( 166 2, 167 [ 168 TestRun("test_1"), 169 TestRun("test_2"), 170 TestRun("test_3"), 171 TestRun("test_4"), 172 TestRun("test_5"), 173 ], 174 {"test_2": 400, "test_3": 300}, 175 {}, 176 must_serial=lambda x: x in ["test_1", "test_3"], 177 sort_by_time=False, 178 ), 179 ) 180 181 def test_calculate_2_shards_with_complete_test_times(self) -> None: 182 expected_shards = [ 183 ( 184 60.0, 185 [ 186 ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55), 187 ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5), 188 ], 189 ), 190 ( 191 58.31, 192 [ 193 ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), 194 ShardedTest(test="long_test2", shard=1, num_shards=1, time=18), 195 ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), 196 ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7), 197 ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), 198 ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6), 199 ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4), 200 ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3), 201 ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01), 202 ], 203 ), 204 ] 205 self.assert_shards_equal( 206 expected_shards, 207 calculate_shards(2, self.tests, self.test_times, self.test_class_times), 208 ) 209 210 def test_calculate_1_shard_with_complete_test_times(self) -> None: 211 tests = self.tests.copy() 212 class_test1 = TestRun("long_test1", excluded=["class2"]) 213 class_test2 = TestRun("long_test1", included=["class2"]) 214 tests.append(class_test1) 215 tests.append(class_test2) 216 217 expected_shards = [ 218 ( 219 140.31, 220 [ 221 ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55), 222 ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), 223 ShardedTest(class_test2, shard=1, num_shards=1, time=21), 224 ShardedTest(test="long_test2", shard=1, num_shards=1, time=18), 225 ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), 226 ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7), 227 ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5), 228 ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), 229 ShardedTest(class_test1, shard=1, num_shards=1, time=1), 230 ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6), 231 ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4), 232 ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3), 233 ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01), 234 ], 235 ) 236 ] 237 self.assert_shards_equal( 238 expected_shards, 239 calculate_shards(1, tests, self.test_times, self.test_class_times), 240 ) 241 242 def test_calculate_5_shards_with_complete_test_times(self) -> None: 243 expected_shards = [ 244 ( 245 55.0, 246 [ShardedTest(test="super_long_test", shard=1, num_shards=1, time=55)], 247 ), 248 (22.0, [ShardedTest(test="long_test1", shard=1, num_shards=1, time=22)]), 249 (18.0, [ShardedTest(test="long_test2", shard=1, num_shards=1, time=18)]), 250 ( 251 11.31, 252 [ 253 ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), 254 ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), 255 ShardedTest(test="short_test2", shard=1, num_shards=1, time=0.6), 256 ShardedTest(test="short_test3", shard=1, num_shards=1, time=0.4), 257 ShardedTest(test="short_test4", shard=1, num_shards=1, time=0.3), 258 ShardedTest(test="short_test5", shard=1, num_shards=1, time=0.01), 259 ], 260 ), 261 ( 262 12.0, 263 [ 264 ShardedTest(test="normal_test2", shard=1, num_shards=1, time=7), 265 ShardedTest(test="normal_test3", shard=1, num_shards=1, time=5), 266 ], 267 ), 268 ] 269 self.assert_shards_equal( 270 expected_shards, 271 calculate_shards(5, self.tests, self.test_times, self.test_class_times), 272 ) 273 274 def test_calculate_2_shards_with_incomplete_test_times(self) -> None: 275 incomplete_test_times = { 276 k: v for k, v in self.test_times.items() if "test1" in k 277 } 278 expected_shards = [ 279 ( 280 22.0, 281 [ 282 ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), 283 ShardedTest( 284 test="super_long_test", shard=1, num_shards=1, time=None 285 ), 286 ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None), 287 ShardedTest(test="short_test2", shard=1, num_shards=1, time=None), 288 ShardedTest(test="short_test4", shard=1, num_shards=1, time=None), 289 ], 290 ), 291 ( 292 10.0, 293 [ 294 ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), 295 ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), 296 ShardedTest(test="long_test2", shard=1, num_shards=1, time=None), 297 ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None), 298 ShardedTest(test="short_test3", shard=1, num_shards=1, time=None), 299 ShardedTest(test="short_test5", shard=1, num_shards=1, time=None), 300 ], 301 ), 302 ] 303 self.assert_shards_equal( 304 expected_shards, 305 calculate_shards( 306 2, 307 self.tests, 308 incomplete_test_times, 309 gen_class_times(incomplete_test_times), 310 ), 311 ) 312 313 def test_calculate_5_shards_with_incomplete_test_times(self) -> None: 314 incomplete_test_times = { 315 k: v for k, v in self.test_times.items() if "test1" in k 316 } 317 expected_shards = [ 318 ( 319 22.0, 320 [ 321 ShardedTest(test="long_test1", shard=1, num_shards=1, time=22), 322 ShardedTest( 323 test="super_long_test", shard=1, num_shards=1, time=None 324 ), 325 ShardedTest(test="short_test3", shard=1, num_shards=1, time=None), 326 ], 327 ), 328 ( 329 9.0, 330 [ 331 ShardedTest(test="normal_test1", shard=1, num_shards=1, time=9), 332 ShardedTest(test="long_test2", shard=1, num_shards=1, time=None), 333 ShardedTest(test="short_test4", shard=1, num_shards=1, time=None), 334 ], 335 ), 336 ( 337 1.0, 338 [ 339 ShardedTest(test="short_test1", shard=1, num_shards=1, time=1), 340 ShardedTest(test="normal_test2", shard=1, num_shards=1, time=None), 341 ShardedTest(test="short_test5", shard=1, num_shards=1, time=None), 342 ], 343 ), 344 ( 345 0.0, 346 [ 347 ShardedTest(test="normal_test3", shard=1, num_shards=1, time=None), 348 ], 349 ), 350 ( 351 0.0, 352 [ 353 ShardedTest(test="short_test2", shard=1, num_shards=1, time=None), 354 ], 355 ), 356 ] 357 self.assert_shards_equal( 358 expected_shards, 359 calculate_shards( 360 5, 361 self.tests, 362 incomplete_test_times, 363 gen_class_times(incomplete_test_times), 364 ), 365 ) 366 367 def test_split_shards(self) -> None: 368 test_times: dict[str, float] = {"test1": THRESHOLD, "test2": THRESHOLD} 369 expected_shards = [ 370 (600.0, [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD)]), 371 (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]), 372 ] 373 self.assert_shards_equal( 374 expected_shards, 375 calculate_shards( 376 2, 377 [TestRun(t) for t in test_times.keys()], 378 test_times, 379 gen_class_times(test_times), 380 ), 381 ) 382 383 test_times = {"test1": THRESHOLD * 4, "test2": THRESHOLD * 2.5} 384 expected_shards = [ 385 ( 386 2200.0, 387 [ 388 ShardedTest(test="test1", shard=1, num_shards=4, time=600.0), 389 ShardedTest(test="test1", shard=3, num_shards=4, time=600.0), 390 ShardedTest(test="test2", shard=1, num_shards=3, time=500.0), 391 ShardedTest(test="test2", shard=3, num_shards=3, time=500.0), 392 ], 393 ), 394 ( 395 1700.0, 396 [ 397 ShardedTest(test="test1", shard=2, num_shards=4, time=600.0), 398 ShardedTest(test="test1", shard=4, num_shards=4, time=600.0), 399 ShardedTest(test="test2", shard=2, num_shards=3, time=500.0), 400 ], 401 ), 402 ] 403 self.assert_shards_equal( 404 expected_shards, 405 calculate_shards( 406 2, 407 [TestRun(t) for t in test_times.keys()], 408 test_times, 409 gen_class_times(test_times), 410 ), 411 ) 412 413 test_times = {"test1": THRESHOLD / 2, "test2": THRESHOLD} 414 expected_shards = [ 415 (600.0, [ShardedTest(test="test2", shard=1, num_shards=1, time=THRESHOLD)]), 416 ( 417 300.0, 418 [ShardedTest(test="test1", shard=1, num_shards=1, time=THRESHOLD / 2)], 419 ), 420 ] 421 self.assert_shards_equal( 422 expected_shards, 423 calculate_shards( 424 2, 425 [TestRun(t) for t in test_times.keys()], 426 test_times, 427 gen_class_times(test_times), 428 ), 429 ) 430 431 def test_zero_tests(self) -> None: 432 self.assertListEqual([(0.0, []), (0.0, [])], calculate_shards(2, [], {}, None)) 433 434 def test_split_shards_random(self) -> None: 435 random.seed(120) 436 for _ in range(100): 437 num_shards = random.randint(1, 10) 438 num_tests = random.randint(1, 100) 439 test_names = [str(i) for i in range(num_tests)] 440 tests = [TestRun(x) for x in test_names] 441 serial = [x for x in test_names if random.randint(0, 1) == 0] 442 has_times = [x for x in test_names if random.randint(0, 1) == 0] 443 random_times: dict[str, float] = { 444 i: random.randint(0, THRESHOLD * 10) for i in has_times 445 } 446 sort_by_time = random.randint(0, 1) == 0 447 448 shards = calculate_shards( 449 num_shards, 450 tests, 451 random_times, 452 None, 453 must_serial=lambda x: x in serial, 454 sort_by_time=sort_by_time, 455 ) 456 457 times = [x[0] for x in shards] 458 max_diff = max(times) - min(times) 459 self.assertTrue(max_diff <= THRESHOLD + (num_tests - len(has_times)) * 60) 460 461 all_sharded_tests: dict[str, list[ShardedTest]] = defaultdict(list) 462 for _, sharded_tests in shards: 463 for sharded_test in sharded_tests: 464 all_sharded_tests[sharded_test.name].append(sharded_test) 465 466 # Check that all test files are represented in the shards 467 self.assertListEqual(sorted(test_names), sorted(all_sharded_tests.keys())) 468 # Check that for each test file, the pytest shards' times adds up to 469 # original and all shards are present 470 for test, sharded_tests in all_sharded_tests.items(): 471 if random_times.get(test) is None: 472 self.assertTrue(len(sharded_tests) == 1) 473 self.assertTrue(sharded_tests[0].time is None) 474 else: 475 # x.time is not None because of the above check 476 self.assertAlmostEqual( 477 random_times[test], sum(x.time for x in sharded_tests) # type: ignore[misc] 478 ) 479 self.assertListEqual( 480 list(range(sharded_tests[0].num_shards)), 481 sorted(x.shard - 1 for x in sharded_tests), 482 ) 483 # Check that sort_by_time is respected 484 if sort_by_time: 485 486 def comparator(a: ShardedTest, b: ShardedTest) -> int: 487 # serial comes first 488 if a.name in serial and b.name not in serial: 489 return -1 490 if a.name not in serial and b.name in serial: 491 return 1 492 # known test times come first 493 if a.time is not None and b.time is None: 494 return -1 495 if a.time is None and b.time is not None: 496 return 1 497 if a.time == b.time: 498 return 0 499 # not None due to the above checks 500 return -1 if a.time > b.time else 1 # type: ignore[operator] 501 502 else: 503 504 def comparator(a: ShardedTest, b: ShardedTest) -> int: 505 # serial comes first 506 if a.name in serial and b.name not in serial: 507 return -1 508 if a.name not in serial and b.name in serial: 509 return 1 510 return test_names.index(a.name) - test_names.index(b.name) 511 512 for _, sharded_tests in shards: 513 self.assertListEqual( 514 sorted(sharded_tests, key=functools.cmp_to_key(comparator)), 515 sharded_tests, 516 ) 517 518 def test_calculate_2_shards_against_optimal_shards(self) -> None: 519 random.seed(120) 520 for _ in range(100): 521 random_times = {k.test_file: random.random() * 10 for k in self.tests} 522 # all test times except first two 523 rest_of_tests = [ 524 i 525 for k, i in random_times.items() 526 if k != "super_long_test" and k != "long_test1" 527 ] 528 sum_of_rest = sum(rest_of_tests) 529 random_times["super_long_test"] = max(sum_of_rest / 2, *rest_of_tests) 530 random_times["long_test1"] = sum_of_rest - random_times["super_long_test"] 531 # An optimal sharding would look like the below, but we don't need to compute this for the test: 532 # optimal_shards = [ 533 # (sum_of_rest, ['super_long_test', 'long_test1']), 534 # (sum_of_rest, [i for i in self.tests if i != 'super_long_test' and i != 'long_test1']), 535 # ] 536 calculated_shards = calculate_shards( 537 2, self.tests, random_times, gen_class_times(random_times) 538 ) 539 max_shard_time = max(calculated_shards[0][0], calculated_shards[1][0]) 540 if sum_of_rest != 0: 541 # The calculated shard should not have a ratio worse than 7/6 for num_shards = 2 542 self.assertGreaterEqual(7.0 / 6.0, max_shard_time / sum_of_rest) 543 sorted_tests = sorted([t.test_file for t in self.tests]) 544 sorted_shard_tests = sorted( 545 calculated_shards[0][1] + calculated_shards[1][1] 546 ) 547 # All the tests should be represented by some shard 548 self.assertEqual(sorted_tests, [x.name for x in sorted_shard_tests]) 549 550 551if __name__ == "__main__": 552 unittest.main() 553