xref: /aosp_15_r20/external/pytorch/tools/test/heuristics/test_interface.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import sys
4import unittest
5from pathlib import Path
6from typing import Any
7
8
# Put the repository root (four directory levels above this file) on sys.path
# so that the ``tools.testing`` imports below resolve against this checkout.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent.parent
sys.path.append(str(REPO_ROOT))

import tools.testing.target_determination.heuristics.interface as interface
from tools.testing.test_run import TestRun


# Undo the temporary sys.path change once the imports above have run.
# NOTE(review): list.remove() drops the *first* matching entry, so if REPO_ROOT
# was already on sys.path before this module ran, that earlier entry is the one
# removed. This line is also never reached if either import above raises.
sys.path.remove(str(REPO_ROOT))
17
18
class TestTD(unittest.TestCase):
    """Base test case with helpers shared by the test classes in this file."""

    def assert_test_scores_almost_equal(
        self, d1: dict[TestRun, float], d2: dict[TestRun, float]
    ) -> None:
        """Assert ``d1`` and ``d2`` have the same keys and near-equal values.

        Values are compared with ``assertAlmostEqual`` so that rounding error
        from adding scores together does not cause spurious failures.
        """
        self.assertEqual(set(d1), set(d2))
        for run in d1:
            actual = d1[run]
            expected = d2[run]
            self.assertAlmostEqual(
                actual, expected, msg=f"{run}: {actual} != {expected}"
            )

    def make_heuristic(self, classname: str) -> Any:
        """Return a new ``HeuristicInterface`` subclass named ``classname``.

        The stub's ``get_prediction_confidence`` ignores its argument and
        returns an empty ``TestPrioritizations``; callers only need a named
        heuristic to register results against.
        """

        class _StubHeuristic(interface.HeuristicInterface):
            def get_prediction_confidence(
                self, tests: list[str]
            ) -> interface.TestPrioritizations:
                # Placeholder result; nothing in this file inspects it.
                return interface.TestPrioritizations([], {})

        # Subclass under the requested name: the stats tests in this file
        # expect e.g. make_heuristic("H1") to show up as heuristic "H1".
        return type(classname, (_StubHeuristic,), {})
38
39
class TestTestPrioritizations(TestTD):
    """Tests for how ``TestPrioritizations`` initialises and updates scores."""

    def _assert_scores(
        self,
        prioritizations: interface.TestPrioritizations,
        tests: list[str],
        expected: dict[TestRun, float],
    ) -> None:
        # The set of original test files must be unchanged, and the stored
        # per-TestRun scores must match ``expected`` exactly (no tolerance).
        self.assertSetEqual(prioritizations._original_tests, set(tests))
        self.assertDictEqual(prioritizations._test_scores, expected)

    def test_init_none(self) -> None:
        files = ["test_a", "test_b"]
        prioritizations = interface.TestPrioritizations(files, {})
        # Files without an explicit score end up with 0.0.
        self._assert_scores(
            prioritizations,
            files,
            {TestRun("test_a"): 0.0, TestRun("test_b"): 0.0},
        )

    def test_init_set_scores_full_files(self) -> None:
        files = ["test_a", "test_b"]
        prioritizations = interface.TestPrioritizations(
            files, {TestRun("test_a"): 0.5, TestRun("test_b"): 0.25}
        )
        self._assert_scores(
            prioritizations,
            files,
            {TestRun("test_a"): 0.5, TestRun("test_b"): 0.25},
        )

    def test_init_set_scores_some_full_files(self) -> None:
        files = ["test_a", "test_b"]
        prioritizations = interface.TestPrioritizations(
            files, {TestRun("test_a"): 0.5}
        )
        self._assert_scores(
            prioritizations,
            files,
            {TestRun("test_a"): 0.5, TestRun("test_b"): 0.0},
        )

    def test_init_set_scores_classes(self) -> None:
        files = ["test_a", "test_b"]
        prioritizations = interface.TestPrioritizations(
            files, {TestRun("test_a", included=["TestA"]): 0.5}
        )
        # Scoring one class splits the file into that class plus the rest.
        self._assert_scores(
            prioritizations,
            files,
            {
                TestRun("test_a", included=["TestA"]): 0.5,
                TestRun("test_a", excluded=["TestA"]): 0.0,
                TestRun("test_b"): 0.0,
            },
        )

    def test_init_set_scores_other_class_naming_convention(self) -> None:
        files = ["test_a", "test_b"]
        # The "file::Class" spelling is expected to split the file the same
        # way as TestRun(file, included=[Class]).
        prioritizations = interface.TestPrioritizations(
            files, {TestRun("test_a::TestA"): 0.5}
        )
        self._assert_scores(
            prioritizations,
            files,
            {
                TestRun("test_a", included=["TestA"]): 0.5,
                TestRun("test_a", excluded=["TestA"]): 0.0,
                TestRun("test_b"): 0.0,
            },
        )

    def test_set_test_score_full_class(self) -> None:
        files = ["test_a", "test_b"]
        prioritizations = interface.TestPrioritizations(files, {})
        prioritizations.set_test_score(TestRun("test_a"), 0.5)
        self._assert_scores(
            prioritizations,
            files,
            {TestRun("test_a"): 0.5, TestRun("test_b"): 0.0},
        )

    def test_set_test_score_mix(self) -> None:
        files = ["test_a", "test_b"]
        prioritizations = interface.TestPrioritizations(
            files, {TestRun("test_b"): -0.5}
        )
        for run, score in (
            (TestRun("test_a"), 0.1),
            (TestRun("test_a::TestA"), 0.2),
            (TestRun("test_a::TestB"), 0.3),
            (TestRun("test_a", included=["TestC"]), 0.4),
        ):
            prioritizations.set_test_score(run, score)
        self._assert_scores(
            prioritizations,
            files,
            {
                TestRun("test_a", included=["TestA"]): 0.2,
                TestRun("test_a", included=["TestB"]): 0.3,
                TestRun("test_a", included=["TestC"]): 0.4,
                TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.1,
                TestRun("test_b"): -0.5,
            },
        )

        # Each step sets the score of one TestRun; the mapping that follows is
        # the complete expected state after that step.
        steps = [
            (
                TestRun("test_a", included=["TestA", "TestB"]),
                0.5,
                {
                    TestRun("test_a", included=["TestA", "TestB"]): 0.5,
                    TestRun("test_a", included=["TestC"]): 0.4,
                    TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.1,
                    TestRun("test_b"): -0.5,
                },
            ),
            (
                TestRun("test_a", excluded=["TestA", "TestB"]),
                0.6,
                {
                    TestRun("test_a", included=["TestA", "TestB"]): 0.5,
                    TestRun("test_a", excluded=["TestA", "TestB"]): 0.6,
                    TestRun("test_b"): -0.5,
                },
            ),
            (
                TestRun("test_a", included=["TestC"]),
                0.7,
                {
                    TestRun("test_a", included=["TestA", "TestB"]): 0.5,
                    TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.6,
                    TestRun("test_a", included=["TestC"]): 0.7,
                    TestRun("test_b"): -0.5,
                },
            ),
            (
                TestRun("test_a", excluded=["TestD"]),
                0.8,
                {
                    TestRun("test_a", excluded=["TestD"]): 0.8,
                    TestRun("test_a", included=["TestD"]): 0.6,
                    TestRun("test_b"): -0.5,
                },
            ),
        ]
        for run, score, expected in steps:
            prioritizations.set_test_score(run, score)
            self.assertDictEqual(prioritizations._test_scores, expected)

        self.assertSetEqual(prioritizations._original_tests, set(files))
        prioritizations.validate()

    def test_add_test_score_mix(self) -> None:
        files = ["test_a", "test_b"]
        prioritizations = interface.TestPrioritizations(
            files, {TestRun("test_b"): -0.5}
        )
        for run, delta in (
            (TestRun("test_a"), 0.1),
            (TestRun("test_a::TestA"), 0.2),
            (TestRun("test_a::TestB"), 0.3),
            (TestRun("test_a", included=["TestC"]), 0.4),
        ):
            prioritizations.add_test_score(run, delta)
        self.assertSetEqual(prioritizations._original_tests, set(files))
        self.assert_test_scores_almost_equal(
            prioritizations._test_scores,
            {
                TestRun("test_a", included=["TestA"]): 0.3,
                TestRun("test_a", included=["TestB"]): 0.4,
                TestRun("test_a", included=["TestC"]): 0.5,
                TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.1,
                TestRun("test_b"): -0.5,
            },
        )

        # Each step adds ``delta`` to every part of the file the TestRun
        # covers; sums are compared with a floating point tolerance.
        steps = [
            (
                TestRun("test_a", included=["TestA", "TestB"]),
                0.5,
                {
                    TestRun("test_a", included=["TestA"]): 0.8,
                    TestRun("test_a", included=["TestB"]): 0.9,
                    TestRun("test_a", included=["TestC"]): 0.5,
                    TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.1,
                    TestRun("test_b"): -0.5,
                },
            ),
            (
                TestRun("test_a", excluded=["TestA", "TestB"]),
                0.6,
                {
                    TestRun("test_a", included=["TestA"]): 0.8,
                    TestRun("test_a", included=["TestB"]): 0.9,
                    TestRun("test_a", included=["TestC"]): 1.1,
                    TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.7,
                    TestRun("test_b"): -0.5,
                },
            ),
            (
                TestRun("test_a", included=["TestC"]),
                0.7,
                {
                    TestRun("test_a", included=["TestA"]): 0.8,
                    TestRun("test_a", included=["TestB"]): 0.9,
                    TestRun("test_a", included=["TestC"]): 1.8,
                    TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.7,
                    TestRun("test_b"): -0.5,
                },
            ),
            (
                TestRun("test_a", excluded=["TestD"]),
                0.8,
                {
                    TestRun("test_a", included=["TestA"]): 1.6,
                    TestRun("test_a", included=["TestB"]): 1.7,
                    TestRun("test_a", included=["TestC"]): 2.6,
                    TestRun("test_a", included=["TestD"]): 0.7,
                    TestRun(
                        "test_a", excluded=["TestA", "TestB", "TestC", "TestD"]
                    ): 1.5,
                    TestRun("test_b"): -0.5,
                },
            ),
            (
                TestRun("test_a", excluded=["TestD", "TestC"]),
                0.1,
                {
                    TestRun("test_a", included=["TestA"]): 1.7,
                    TestRun("test_a", included=["TestB"]): 1.8,
                    TestRun("test_a", included=["TestC"]): 2.6,
                    TestRun("test_a", included=["TestD"]): 0.7,
                    TestRun(
                        "test_a", excluded=["TestA", "TestB", "TestC", "TestD"]
                    ): 1.6,
                    TestRun("test_b"): -0.5,
                },
            ),
        ]
        for run, delta, expected in steps:
            prioritizations.add_test_score(run, delta)
            self.assert_test_scores_almost_equal(
                prioritizations._test_scores, expected
            )

        self.assertSetEqual(prioritizations._original_tests, set(files))
        prioritizations.validate()
262
263
class TestAggregatedHeuristics(TestTD):
    """Tests for ``AggregatedHeuristics.get_aggregated_priorities``."""

    def check(
        self,
        tests: list[str],
        test_prioritizations: list[dict[TestRun, float]],
        expected: dict[TestRun, float],
    ) -> None:
        """Feed one stub heuristic per score map, then compare with ``expected``.

        The i-th entry of ``test_prioritizations`` is reported by a heuristic
        named ``H{i}``. Aggregated scores are compared with a floating point
        tolerance.
        """
        aggregator = interface.AggregatedHeuristics(tests)
        for idx, scores in enumerate(test_prioritizations):
            stub_cls = self.make_heuristic(f"H{idx}")
            aggregator.add_heuristic_results(
                stub_cls(), interface.TestPrioritizations(tests, scores)
            )
        combined = aggregator.get_aggregated_priorities()
        self.assert_test_scores_almost_equal(combined._test_scores, expected)

    def test_get_aggregated_priorities_mix_1(self) -> None:
        files = ["test_a", "test_b", "test_c"]
        per_heuristic = [
            {TestRun("test_a"): 0.5},
            {TestRun("test_a::TestA"): 0.25},
            {TestRun("test_c"): 0.8},
        ]
        # TestA receives the file-level 0.5 plus its own 0.25.
        want = {
            TestRun("test_a", excluded=["TestA"]): 0.5,
            TestRun("test_a", included=["TestA"]): 0.75,
            TestRun("test_b"): 0.0,
            TestRun("test_c"): 0.8,
        }
        self.check(files, per_heuristic, want)

    def test_get_aggregated_priorities_mix_2(self) -> None:
        files = ["test_a", "test_b", "test_c"]
        per_heuristic = [
            {
                TestRun("test_a", included=["TestC"]): 0.5,
                TestRun("test_b"): 0.25,
                TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.8,
            },
            {
                TestRun("test_a::TestA"): 0.25,
                TestRun("test_b::TestB"): 0.5,
                TestRun("test_a::TestB"): 0.75,
                TestRun("test_a", excluded=["TestA", "TestB"]): 0.8,
            },
            {TestRun("test_c"): 0.8},
        ]
        want = {
            TestRun("test_a", included=["TestA"]): 0.25,
            TestRun("test_a", included=["TestB"]): 0.75,
            TestRun("test_a", included=["TestC"]): 1.3,
            TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 1.6,
            TestRun("test_b", included=["TestB"]): 0.75,
            TestRun("test_b", excluded=["TestB"]): 0.25,
            TestRun("test_c"): 0.8,
        }
        self.check(files, per_heuristic, want)

    def test_get_aggregated_priorities_mix_3(self) -> None:
        files = ["test_a"]
        per_heuristic = [
            {
                TestRun("test_a", included=["TestA"]): 0.1,
                TestRun("test_a", included=["TestC"]): 0.1,
                TestRun("test_a", excluded=["TestA", "TestB", "TestC"]): 0.1,
            },
            {
                TestRun("test_a", excluded=["TestD"]): 0.1,
            },
            {
                TestRun("test_a", included=["TestC"]): 0.1,
            },
            {
                TestRun("test_a", included=["TestB", "TestC"]): 0.1,
            },
            {
                TestRun("test_a", included=["TestC"]): 0.1,
                TestRun("test_a", included=["TestD"]): 0.1,
            },
            {
                TestRun("test_a"): 0.1,
            },
        ]
        want = {
            TestRun("test_a", included=["TestA"]): 0.3,
            TestRun("test_a", included=["TestB"]): 0.3,
            TestRun("test_a", included=["TestC"]): 0.6,
            TestRun("test_a", included=["TestD"]): 0.3,
            TestRun("test_a", excluded=["TestA", "TestB", "TestC", "TestD"]): 0.3,
        }
        self.check(files, per_heuristic, want)
363        )
364
365
class TestAggregatedHeuristicsTestStats(TestTD):
    """Tests for ``AggregatedHeuristics.get_test_stats``."""

    def _make_aggregator(
        self, tests: list[str], *heuristic_scores: dict[TestRun, float]
    ) -> interface.AggregatedHeuristics:
        """Return an aggregator fed one stub heuristic per score mapping.

        The first mapping is registered under a heuristic named "H1", the
        second under "H2", and so on.
        """
        aggregator = interface.AggregatedHeuristics(tests)
        for number, scores in enumerate(heuristic_scores, start=1):
            aggregator.add_heuristic_results(
                self.make_heuristic(f"H{number}")(),
                interface.TestPrioritizations(tests, scores),
            )
        return aggregator

    def _assert_only_allowed_types(self, value: Any, path: str = "stats") -> None:
        """Recursively assert ``value`` is made only of str/int/float/list/dict.

        Dict keys must be strings. List items may be any allowed type, not
        only dicts. ``path`` locates the offending value in failure messages.
        """
        # NOTE(review): presumably the stats are serialized (e.g. to JSON)
        # downstream; the allowed types match JSON's value types.
        if isinstance(value, dict):
            for key, item in value.items():
                self.assertIsInstance(
                    key, str, f"key {key!r} at {path} is not a str"
                )
                self._assert_only_allowed_types(item, f"{path}[{key!r}]")
        elif isinstance(value, list):
            for idx, item in enumerate(value):
                self._assert_only_allowed_types(item, f"{path}[{idx}]")
        else:
            self.assertIsInstance(
                value,
                (str, float, int),
                f"{value!r} at {path} is not a str, float, int, list, or dict",
            )

    def test_get_test_stats_with_whole_tests(self) -> None:
        self.maxDiff = None
        aggregator = self._make_aggregator(
            ["test1", "test2", "test3", "test4", "test5"],
            {
                TestRun("test3"): 0.3,
                TestRun("test4"): 0.1,
            },
            {
                TestRun("test5"): 0.5,
            },
        )

        expected_test3_stats = {
            "test_name": "test3",
            "test_filters": "",
            "heuristics": [
                {
                    "position": 0,
                    "score": 0.3,
                    "heuristic_name": "H1",
                    "trial_mode": False,
                },
                {
                    "position": 3,
                    "score": 0.0,
                    "heuristic_name": "H2",
                    "trial_mode": False,
                },
            ],
            "aggregated": {"position": 1, "score": 0.3},
            "aggregated_trial": {"position": 1, "score": 0.3},
        }

        test3_stats = aggregator.get_test_stats(TestRun("test3"))

        self.assertDictEqual(test3_stats, expected_test3_stats)

    def test_get_test_stats_only_contains_allowed_types(self) -> None:
        self.maxDiff = None
        aggregator = self._make_aggregator(
            ["test1", "test2", "test3", "test4", "test5"],
            {
                TestRun("test3"): 0.3,
                TestRun("test4"): 0.1,
            },
            {
                TestRun("test5::classA"): 0.5,
            },
        )

        stats3 = aggregator.get_test_stats(TestRun("test3"))
        stats5 = aggregator.get_test_stats(TestRun("test5::classA"))

        for stats in (stats3, stats5):
            # The top level must itself be a dict with string keys.
            self.assertIsInstance(stats, dict)
            self._assert_only_allowed_types(stats)

    def test_get_test_stats_gets_rank_for_test_classes(self) -> None:
        self.maxDiff = None
        aggregator = self._make_aggregator(
            ["test1", "test2", "test3", "test4", "test5"],
            {
                TestRun("test3"): 0.3,
                TestRun("test4"): 0.1,
            },
            {
                TestRun("test5::classA"): 0.5,
            },
        )

        stats_inclusive = aggregator.get_test_stats(
            TestRun("test5", included=["classA"])
        )
        stats_exclusive = aggregator.get_test_stats(
            TestRun("test5", excluded=["classA"])
        )
        expected_inclusive = {
            "test_name": "test5",
            "test_filters": "classA",
            "heuristics": [
                {
                    "position": 4,
                    "score": 0.0,
                    "heuristic_name": "H1",
                    "trial_mode": False,
                },
                {
                    "position": 0,
                    "score": 0.5,
                    "heuristic_name": "H2",
                    "trial_mode": False,
                },
            ],
            "aggregated": {"position": 0, "score": 0.5},
            "aggregated_trial": {"position": 0, "score": 0.5},
        }
        expected_exclusive = {
            "test_name": "test5",
            "test_filters": "not (classA)",
            "heuristics": [
                {
                    "position": 4,
                    "score": 0.0,
                    "heuristic_name": "H1",
                    "trial_mode": False,
                },
                {
                    "position": 5,
                    "score": 0.0,
                    "heuristic_name": "H2",
                    "trial_mode": False,
                },
            ],
            "aggregated": {"position": 5, "score": 0.0},
            "aggregated_trial": {"position": 5, "score": 0.0},
        }

        self.assertDictEqual(stats_inclusive, expected_inclusive)
        self.assertDictEqual(stats_exclusive, expected_exclusive)

    def test_get_test_stats_works_with_class_granularity_heuristics(self) -> None:
        aggregator = self._make_aggregator(
            ["test1", "test2", "test3", "test4", "test5"],
            {
                TestRun("test2"): 0.3,
            },
            {
                TestRun("test2::TestFooClass"): 0.5,
            },
        )

        # These should not throw an error
        aggregator.get_test_stats(TestRun("test2::TestFooClass"))
        aggregator.get_test_stats(TestRun("test2"))
545        aggregator.get_test_stats(TestRun("test2"))
546
547
class TestJsonParsing(TestTD):
    """Round-trip tests for the ``to_json`` / ``from_json`` helpers."""

    def test_json_parsing_matches_TestPrioritizations(self) -> None:
        tests = ["test1", "test2", "test3", "test4", "test5"]
        tp = interface.TestPrioritizations(
            tests,
            {
                TestRun("test3", included=["ClassA"]): 0.8,
                TestRun("test3", excluded=["ClassA"]): 0.2,
                TestRun("test4"): 0.7,
                TestRun("test5"): 0.6,
            },
        )
        # Serialize and parse back; the original tests and scores must survive.
        tp_json_to_tp = interface.TestPrioritizations.from_json(tp.to_json())

        self.assertSetEqual(tp._original_tests, tp_json_to_tp._original_tests)
        self.assertDictEqual(tp._test_scores, tp_json_to_tp._test_scores)

    def test_json_parsing_matches_TestRun(self) -> None:
        testrun = TestRun("test1", included=["classA", "classB"])
        testrun_json_to_test = TestRun.from_json(testrun.to_json())

        # assertEqual (unlike assertTrue(a == b)) reports both values on failure.
        self.assertEqual(testrun, testrun_json_to_test)
572
573
if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()
576