from __future__ import annotations

from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable
from warnings import warn

from tools.testing.target_determination.heuristics.interface import (
    HeuristicInterface,
    TestPrioritizations,
)
from tools.testing.target_determination.heuristics.utils import (
    normalize_ratings,
    query_changed_files,
)
from tools.testing.test_run import TestRun


# Four directories above this file; going by the name, presumably the
# repository root.
REPO_ROOT = Path(__file__).parent.parent.parent.parent

# Maps a canonical keyword to alternative names that should count as the same
# keyword. ``sanitize_folder_name`` collapses any listed name to its key, and
# ``file_matches_keyword`` also accepts a synonym appearing as a folder or as a
# substring of the path.
keyword_synonyms: dict[str, list[str]] = {
    "amp": ["mixed_precision"],
    "quant": ["quantized", "quantization", "quantize"],
    "decomp": ["decomposition", "decompositions"],
    "numpy": ["torch_np", "numpy_tests"],
    "ops": ["opinfo"],
    "hop": ["higher_order_op"],
    "aot": ["flex_attention", "autograd"],
    "inductor": ["dynamo", "export"],  # not actually synonyms but they interact a lot
}

# Folder names too generic to carry signal; never emitted as keywords.
not_keyword = [
    "torch",
    "test",
    "tests",
    "util",
    "utils",
    "func",
    "src",
    "c",
    "ns",
    "tools",
    "internal",
]

# Per-keyword predicates that replace the default "keyword is a substring of
# the path" check, to avoid known false positives.
custom_matchers: dict[str, Callable[[str], bool]] = {
    # "onnx" contains "nn"; blank it out so ONNX paths do not count as nn.
    "nn": lambda x: "nn" in x.replace("onnx", "_"),
    # "c10d" contains "c10"; blank it out so c10d paths do not count as c10.
    "c10": lambda x: "c10" in x.replace("c10d", "_"),
}


@lru_cache(maxsize=1)
def get_keywords(file: str) -> list[str]:
    """Return the sanitized folder names of ``file`` that serve as keywords.

    The last path component (the file name itself) is ignored. Each folder is
    normalized with ``sanitize_folder_name``. Names listed in ``not_keyword``
    are dropped, and so are names that sanitize to an empty string.

    Only one entry is cached because ``Filepath.get_prediction_confidence``
    queries the same test path once per keyword in a row. The returned list
    object is shared between cached calls, so callers must not mutate it.
    """
    folders = (sanitize_folder_name(part) for part in Path(file).parts[:-1])
    # An empty keyword (e.g. from a folder literally named "_") is a substring
    # of every path, so file_matches_keyword would match every test with it.
    return [kw for kw in folders if kw and kw not in not_keyword]


def sanitize_folder_name(folder_name: str) -> str:
    """Normalize a single folder name into a keyword.

    A single leading underscore is stripped, so ``_inductor`` becomes
    ``inductor``. Then any name found in ``keyword_synonyms`` is replaced by
    its canonical key; this covers both the key itself and its listed
    synonyms. Any other name is returned unchanged.
    """
    if folder_name.startswith("_"):
        folder_name = folder_name[1:]

    for canonical, synonyms in keyword_synonyms.items():
        if folder_name == canonical or folder_name in synonyms:
            return canonical

    return folder_name


def file_matches_keyword(file: str, keyword: str) -> bool:
    """Return whether the path ``file`` should be considered related to ``keyword``.

    The checks are tried in this order:

    1. ``keyword`` is one of the file's own folder keywords (``get_keywords``).
    2. A synonym of ``keyword`` is one of those folder keywords, or appears
       anywhere in the path as a substring.
    3. If ``custom_matchers`` has a predicate for ``keyword``, its result
       decides. Otherwise the file matches when ``keyword`` is a substring
       of the path.
    """
    keywords = get_keywords(file)
    if keyword in keywords:
        return True

    if any(
        syn in keywords or syn in file for syn in keyword_synonyms.get(keyword, [])
    ):
        return True

    matcher = custom_matchers.get(keyword)
    if matcher is not None:
        return matcher(file)
    return keyword in file


class Filepath(HeuristicInterface):
    """Heuristic based on the folders in changed file paths.

    Every folder of every changed file contributes a keyword. Each test is
    rated by summing how often the changed files produced each keyword the
    test matches (see ``file_matches_keyword``).
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations:
        """Rate ``tests`` by folder-keyword overlap with the changed files.

        Args:
            tests: test names/paths to rate.

        Returns:
            A ``TestPrioritizations`` over ``tests``. Ratings are passed through
            ``normalize_ratings(..., 0.25, min_value=0.125)``.
        """
        try:
            changed_files = query_changed_files()
        except Exception as e:
            # Best effort: with no changed files, keyword_frequency stays empty
            # and no test accumulates a rating.
            warn(f"Can't query changed test files due to {e}")
            changed_files = []

        # A keyword appearing several times (even within one path) counts each time.
        keyword_frequency: dict[str, int] = defaultdict(int)
        for cf in changed_files:
            for keyword in get_keywords(cf):
                keyword_frequency[keyword] += 1

        test_ratings: dict[str, float] = defaultdict(float)
        for test in tests:
            for keyword, frequency in keyword_frequency.items():
                if file_matches_keyword(test, keyword):
                    test_ratings[test] += frequency

        # Every key was taken from `tests`, so no membership filter is needed
        # (the old `k in tests` scan was O(len(tests)) per entry).
        ratings_by_run: dict[TestRun, float] = {
            TestRun(test): rating for test, rating in test_ratings.items()
        }
        return TestPrioritizations(
            tests, normalize_ratings(ratings_by_run, 0.25, min_value=0.125)
        )