xref: /aosp_15_r20/external/pytorch/tools/test/heuristics/test_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import sys
4import unittest
5from pathlib import Path
6from typing import Any
7
8
# Repository root: this file lives at tools/test/heuristics/test_utils.py,
# so the root is four directories above it.
REPO_ROOT = Path(__file__).resolve().parents[3]

# Make the repository root importable only for as long as it takes to load
# the modules under test. The entry is removed in ``finally`` so sys.path is
# restored even if an import fails. It is only added (and later removed) when
# it was not already present, so a pre-existing entry keeps its position
# instead of being removed by ``list.remove`` (which drops the first match).
_repo_root_str = str(REPO_ROOT)
_added_repo_root = _repo_root_str not in sys.path
if _added_repo_root:
    sys.path.append(_repo_root_str)
try:
    import tools.testing.target_determination.heuristics.utils as utils
    from tools.testing.test_run import TestRun
finally:
    if _added_repo_root:
        sys.path.remove(_repo_root_str)
    # Keep the module namespace identical to before: only REPO_ROOT, utils
    # and TestRun are defined here.
    del _repo_root_str, _added_repo_root
17
18
class TestHeuristicsUtils(unittest.TestCase):
    """Unit tests for helpers in the target-determination heuristics ``utils``."""

    def assertDictAlmostEqual(
        self, first: dict[TestRun, Any], second: dict[TestRun, Any]
    ) -> None:
        """Assert that two dicts have identical keys and almost-equal values.

        Keys must match exactly. Each pair of values is compared with
        ``assertAlmostEqual`` using its default precision.

        Raises:
            AssertionError: if the key sets differ or any value pair is not
                almost equal.
        """
        # Comparing real sets makes assertEqual dispatch to assertSetEqual,
        # whose failure message lists the keys missing from each side.
        self.assertEqual(set(first), set(second))
        for key, value in first.items():
            # Name the offending key; the default message only shows values.
            self.assertAlmostEqual(value, second[key], msg=f"key {key!r}")

    def test_normalize_ratings(self) -> None:
        """Check ``normalize_ratings`` against hand-computed expectations."""
        ratings: dict[TestRun, float] = {
            TestRun("test1"): 1,
            TestRun("test2"): 2,
            TestRun("test3"): 4,
        }
        # Every expected value below equals
        #   min_value + (max_value - min_value) * rating / max(ratings)
        # where min_value behaves as 0 when it is not passed.

        # The largest rating already equals max_value, so nothing changes.
        normalized = utils.normalize_ratings(ratings, 4)
        self.assertDictAlmostEqual(normalized, ratings)

        # Scale down so the largest rating becomes 0.1.
        normalized = utils.normalize_ratings(ratings, 0.1)
        self.assertDictAlmostEqual(
            normalized,
            {
                TestRun("test1"): 0.025,
                TestRun("test2"): 0.05,
                TestRun("test3"): 0.1,
            },
        )

        # With a floor: values are spread over [min_value, max_value].
        normalized = utils.normalize_ratings(ratings, 0.2, min_value=0.1)
        self.assertDictAlmostEqual(
            normalized,
            {
                TestRun("test1"): 0.125,
                TestRun("test2"): 0.15,
                TestRun("test3"): 0.2,
            },
        )
55
56
if __name__ == "__main__":
    # Entry point when this file is executed directly rather than collected
    # by a test runner; unittest.main() exits the process with the result.
    unittest.main(module="__main__")
59