# xref: /aosp_15_r20/external/pytorch/tools/testing/target_determination/heuristics/filepath.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from collections import defaultdict
4from functools import lru_cache
5from pathlib import Path
6from typing import Any, Callable
7from warnings import warn
8
9from tools.testing.target_determination.heuristics.interface import (
10    HeuristicInterface,
11    TestPrioritizations,
12)
13from tools.testing.target_determination.heuristics.utils import (
14    normalize_ratings,
15    query_changed_files,
16)
17from tools.testing.test_run import TestRun
18
19
# Root of the repository checkout.  This module lives at
# tools/testing/target_determination/heuristics/filepath.py, so the root is
# five directory levels above it:
#   heuristics -> target_determination -> testing -> tools -> <root>
# (four levels would stop at the tools/ directory).
REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent
21
# Maps a canonical keyword to the alternative spellings that are treated as
# the same keyword.  sanitize_folder_name() collapses any alias to its
# canonical key, and file_matches_keyword() also accepts a file whose path
# contains one of the aliases.
keyword_synonyms: dict[str, list[str]] = {
    "amp": ["mixed_precision"],
    "quant": ["quantized", "quantization", "quantize"],
    "decomp": ["decomposition", "decompositions"],
    "numpy": ["torch_np", "numpy_tests"],
    "ops": ["opinfo"],
    "hop": ["higher_order_op"],
    "aot": ["flex_attention", "autograd"],
    # not actually synonyms but they interact a lot
    "inductor": ["dynamo", "export"],
}
32
# Folder names that get_keywords() discards after sanitizing, so they never
# become keywords.
not_keyword: list[str] = [
    "torch",
    "test",
    "tests",
    "util",
    "utils",
    "func",
    "src",
    "c",
    "ns",
    "tools",
    "internal",
]
46
def _path_mentions_nn(path: str) -> bool:
    # "onnx" itself contains "nn"; blank it out so it does not count as a hit.
    return "nn" in path.replace("onnx", "_")


def _path_mentions_c10(path: str) -> bool:
    # "c10d" contains "c10"; blank it out so it does not count as a hit.
    return "c10" in path.replace("c10d", "_")


# Per-keyword predicates that file_matches_keyword() uses in place of its
# plain substring check on the file path.
custom_matchers: dict[str, Callable[[str], bool]] = {
    "nn": _path_mentions_nn,
    "c10": _path_mentions_c10,
}
51
52
@lru_cache(maxsize=1)
def get_keywords(file: str) -> list[str]:
    """Return the keywords derived from the folders in ``file``'s path.

    Every path component except the last one (the file name) is passed
    through ``sanitize_folder_name``; results listed in ``not_keyword`` are
    dropped.  Order follows the path and duplicates are kept.

    Only the most recent call is cached, and a cache hit hands back the same
    list object, so callers must not mutate the result.
    """
    sanitized = (sanitize_folder_name(part) for part in Path(file).parts[:-1])
    return [name for name in sanitized if name not in not_keyword]
60
61
def sanitize_folder_name(folder_name: str) -> str:
    """Normalize a folder name into a keyword.

    A single leading underscore is removed; if the result is a canonical
    keyword in ``keyword_synonyms`` or one of its aliases, the canonical
    keyword is returned.  Otherwise the stripped name is returned unchanged.
    """
    stripped = folder_name[1:] if folder_name.startswith("_") else folder_name

    for canonical, aliases in keyword_synonyms.items():
        if stripped == canonical or stripped in aliases:
            return canonical

    return stripped
71
72
def file_matches_keyword(file: str, keyword: str) -> bool:
    """Report whether ``file`` is related to ``keyword``.

    Checks, in order:
      1. ``keyword`` is one of the file's folder keywords;
      2. any synonym of ``keyword`` is a folder keyword or a substring of
         the path;
      3. the keyword's entry in ``custom_matchers`` accepts the path, or,
         when there is no such entry, ``keyword`` is a substring of the path.
    """
    folder_keywords = get_keywords(file)
    if keyword in folder_keywords:
        return True

    for synonym in keyword_synonyms.get(keyword, []):
        if synonym in folder_keywords or synonym in file:
            return True

    matcher = custom_matchers.get(keyword)
    if matcher is None:
        # No special-case matcher: fall back to a plain substring test.
        return keyword in file
    return matcher(file)
82
83
class Filepath(HeuristicInterface):
    # Heuristic based on folders in the file path.  Takes each folder of each
    # changed file and attempts to find matches based on those folders.
    #
    # A test's raw rating is the sum, over every keyword obtained from the
    # changed files, of the number of changed files that produced that
    # keyword, counted only for keywords the test's path matches.
    def __init__(self, **kwargs: dict[str, Any]) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """Rate ``tests`` by how their paths relate to the changed files.

        Args:
            tests: Test file names to rate.

        Returns:
            A ``TestPrioritizations`` built from ``tests`` and the ratings
            after ``normalize_ratings(..., 0.25, min_value=0.125)``.  Only
            tests that matched at least one keyword have a rating entry.
        """
        try:
            changed_files = query_changed_files()
        except Exception as e:
            # Best effort: without the list of changed files there are no
            # keywords, so no test receives a rating.
            warn(f"Can't query changed test files due to {e}")
            changed_files = []

        # keyword -> number of changed files whose folders yield that keyword
        # (a file counts once per occurrence of the keyword in its path).
        keyword_frequency: dict[str, int] = defaultdict(int)
        for cf in changed_files:
            for keyword in get_keywords(cf):
                keyword_frequency[keyword] += 1

        ratings_by_name: dict[str, float] = defaultdict(float)
        for test in tests:
            for keyword, frequency in keyword_frequency.items():
                if file_matches_keyword(test, keyword):
                    ratings_by_name[test] += frequency

        # Every key above was inserted while iterating `tests`, so no extra
        # membership filter against `tests` is needed when converting.
        test_ratings = {
            TestRun(name): rating for name, rating in ratings_by_name.items()
        }
        return TestPrioritizations(
            tests, normalize_ratings(test_ratings, 0.25, min_value=0.125)
        )
113