xref: /aosp_15_r20/external/pytorch/tools/stats/import_test_stats.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3import datetime
4import json
5import os
6import pathlib
7import shutil
8from typing import Any, Callable, cast, Dict, List, Optional, Union
9from urllib.request import urlopen
10
# Absolute path to the repository root (this file lives at tools/stats/, three
# levels below the root).
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent
12
13
def get_disabled_issues() -> List[str]:
    """Return issue numbers from the REENABLED_ISSUES env var.

    REENABLED_ISSUES is a comma-separated list of issue numbers whose
    disabled tests should be re-enabled. Returns an empty list when the
    variable is unset or empty.
    """
    reenabled_issues = os.getenv("REENABLED_ISSUES", "")
    # "".split(",") yields [""], so drop empty entries to avoid a bogus
    # empty-string "issue number" when the env var is unset/empty.
    issue_numbers = [issue for issue in reenabled_issues.split(",") if issue]
    print("Ignoring disabled issues: ", issue_numbers)
    return issue_numbers
19
20
# Cache filenames for the slow/disabled test sets (written into the caller's
# chosen directory by get_slow_tests / get_disabled_tests).
SLOW_TESTS_FILE = ".pytorch-slow-tests.json"
DISABLED_TESTS_FILE = ".pytorch-disabled-tests.json"
# Folder (relative to REPO_ROOT) where stats from pytorch/test-infra's
# generated-stats branch are cached.
ADDITIONAL_CI_FILES_FOLDER = pathlib.Path(".additional_ci_files")
# Local filenames for test-timing and target-determination (TD) heuristic
# data downloaded by the helpers below.
TEST_TIMES_FILE = "test-times.json"
TEST_CLASS_TIMES_FILE = "test-class-times.json"
TEST_FILE_RATINGS_FILE = "test-file-ratings.json"
TEST_CLASS_RATINGS_FILE = "test-class-ratings.json"
TD_HEURISTIC_PROFILING_FILE = "td_heuristic_profiling.json"
TD_HEURISTIC_HISTORICAL_EDITED_FILES = "td_heuristic_historical_edited_files.json"
TD_HEURISTIC_PREVIOUSLY_FAILED = "previous_failures.json"
TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL = "previous_failures_additional.json"

# How long a cached file stays valid before fetch_and_cache re-downloads it.
# NOTE: .seconds (not .total_seconds()) is correct here only because the
# delta is under one day; 3h -> 10800.
FILE_CACHE_LIFESPAN_SECONDS = datetime.timedelta(hours=3).seconds
34
35
def fetch_and_cache(
    dirpath: Union[str, pathlib.Path],
    name: str,
    url: str,
    process_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
) -> Dict[str, Any]:
    """Download a JSON file from ``url`` and cache it at ``dirpath/name``.

    The cache is shared between test processes: if another process already
    downloaded the file recently enough (see FILE_CACHE_LIFESPAN_SECONDS),
    the cached copy is returned instead of re-downloading.

    Args:
        dirpath: Directory holding the cached file; created if missing.
        name: Filename of the cached file inside ``dirpath``.
        url: URL serving the JSON payload.
        process_fn: Transformation applied to the parsed JSON before caching.

    Returns:
        The processed JSON dict, or {} if all download attempts fail.
    """
    # parents=True so a nested cache folder (e.g. <repo>/.additional_ci_files)
    # does not have to exist beforehand -- plain mkdir would raise.
    pathlib.Path(dirpath).mkdir(parents=True, exist_ok=True)

    path = os.path.join(dirpath, name)
    print(f"Downloading {url} to {path}")

    def is_cached_file_valid() -> bool:
        # Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check
        # could make a HEAD request and check/store the file's ETag
        fname = pathlib.Path(path)
        now = datetime.datetime.now()
        mtime = datetime.datetime.fromtimestamp(fname.stat().st_mtime)
        diff = now - mtime
        return diff.total_seconds() < FILE_CACHE_LIFESPAN_SECONDS

    if os.path.exists(path) and is_cached_file_valid():
        # Another test process already downloaded the file, so don't re-do it
        with open(path) as f:
            return cast(Dict[str, Any], json.load(f))

    # Best-effort: retry up to three times before giving up with {}.
    for _ in range(3):
        try:
            contents = urlopen(url, timeout=5).read().decode("utf-8")
            processed_contents = process_fn(json.loads(contents))
            with open(path, "w") as f:
                f.write(json.dumps(processed_contents))
            return processed_contents
        except Exception as e:
            print(f"Could not download {url} because: {e}.")
    print(f"All retries exhausted, downloading {url} failed.")
    return {}
75
76
def get_slow_tests(
    dirpath: str, filename: str = SLOW_TESTS_FILE
) -> Optional[Dict[str, float]]:
    """Fetch the slow-test mapping and cache it as ``dirpath/filename``.

    Returns {} on any failure, which leaves all tests enabled.
    """
    slow_tests_url = "https://ossci-metrics.s3.amazonaws.com/slow-tests.json?versionId=oKMp2dsjwgbtvuXJrL9fZbQiSJkiw91I"
    try:
        result = fetch_and_cache(dirpath, filename, slow_tests_url, lambda data: data)
    except Exception:
        print("Couldn't download slow test set, leaving all tests enabled...")
        result = {}
    return result
86
87
def get_test_times() -> Dict[str, Dict[str, float]]:
    """Download per-file, per-test duration stats from test-infra's generated stats."""
    on_failure = "Couldn't download test times..."
    return get_from_test_infra_generated_stats(
        "test-times.json", TEST_TIMES_FILE, on_failure
    )
94
95
def get_test_class_times() -> Dict[str, Dict[str, float]]:
    """Download per-file, per-class duration stats from test-infra's generated stats."""
    on_failure = "Couldn't download test times..."
    return get_from_test_infra_generated_stats(
        "test-class-times.json", TEST_CLASS_TIMES_FILE, on_failure
    )
102
103
def get_disabled_tests(
    dirpath: str, filename: str = DISABLED_TESTS_FILE
) -> Optional[Dict[str, Any]]:
    """Fetch the condensed disabled-tests mapping, caching it as ``dirpath/filename``.

    Tests whose issues were re-enabled (via REENABLED_ISSUES) are removed,
    and the pr_num field is dropped from each entry. Returns {} on failure,
    which leaves all tests enabled.
    """

    def process_disabled_test(the_response: Dict[str, Any]) -> Dict[str, Any]:
        # Remove re-enabled tests and condense further by dropping pr_num.
        reenabled = get_disabled_issues()
        return {
            test_name: (link, platforms)
            for test_name, (pr_num, link, platforms) in the_response.items()
            if pr_num not in reenabled
        }

    url = "https://ossci-metrics.s3.amazonaws.com/disabled-tests-condensed.json?versionId=0zzD6gFqZ9l2Vs1SYtjHXSuVLz6BaLtE"
    try:
        return fetch_and_cache(dirpath, filename, url, process_disabled_test)
    except Exception:
        print("Couldn't download test skip set, leaving all tests enabled...")
        return {}
125
126
def get_test_file_ratings() -> Dict[str, Any]:
    """Download per-file relevance ratings used to reorder tests."""
    on_failure = "Couldn't download test file ratings file, not reordering..."
    return get_from_test_infra_generated_stats(
        "file_test_rating.json", TEST_FILE_RATINGS_FILE, on_failure
    )
133
134
def get_test_class_ratings() -> Dict[str, Any]:
    """Download per-class relevance ratings used to reorder tests."""
    on_failure = "Couldn't download test class ratings file, not reordering..."
    return get_from_test_infra_generated_stats(
        "file_test_class_rating.json", TEST_CLASS_RATINGS_FILE, on_failure
    )
141
142
def get_td_heuristic_historial_edited_files_json() -> Dict[str, Any]:
    """Download the historical-edited-files TD heuristic data.

    NOTE: "historial" is a long-standing typo in this function's name, kept
    unchanged for caller compatibility.
    """
    on_failure = (
        "Couldn't download td_heuristic_historical_edited_files.json, not reordering..."
    )
    return get_from_test_infra_generated_stats(
        "td_heuristic_historical_edited_files.json",
        TD_HEURISTIC_HISTORICAL_EDITED_FILES,
        on_failure,
    )
149
150
def get_td_heuristic_profiling_json() -> Dict[str, Any]:
    """Download the profiling-based TD heuristic data."""
    on_failure = "Couldn't download td_heuristic_profiling.json not reordering..."
    return get_from_test_infra_generated_stats(
        "td_heuristic_profiling.json", TD_HEURISTIC_PROFILING_FILE, on_failure
    )
157
158
def copy_pytest_cache() -> None:
    """Copy pytest's "last failed" cache into the additional-CI-files folder.

    No-op if pytest has not produced a cache yet.
    """
    original_path = REPO_ROOT / ".pytest_cache/v/cache/lastfailed"
    if not original_path.exists():
        return
    destination = (
        REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER / TD_HEURISTIC_PREVIOUSLY_FAILED
    )
    # Nothing in this function otherwise guarantees the destination folder
    # exists; without this, copyfile raises FileNotFoundError on a fresh
    # checkout where no stats have been fetched yet.
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(original_path, destination)
167
168
def copy_additional_previous_failures() -> None:
    """Copy the additional previous-failures file into the additional-CI-files folder.

    No-op if the source file does not exist.
    """
    original_path = (
        REPO_ROOT / ".pytest_cache" / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL
    )
    if not original_path.exists():
        return
    destination = (
        REPO_ROOT
        / ADDITIONAL_CI_FILES_FOLDER
        / TD_HEURISTIC_PREVIOUSLY_FAILED_ADDITIONAL
    )
    # Nothing in this function otherwise guarantees the destination folder
    # exists; without this, copyfile raises FileNotFoundError on a fresh
    # checkout where no stats have been fetched yet.
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(original_path, destination)
181
182
def get_from_test_infra_generated_stats(
    from_file: str, to_file: str, failure_explanation: str
) -> Dict[str, Any]:
    """Fetch ``from_file`` from pytorch/test-infra's generated-stats branch.

    The result is cached as ``to_file`` inside the additional-CI-files
    folder. On failure, prints ``failure_explanation`` and returns {}.
    """
    stats_url = f"https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/{from_file}"
    try:
        return fetch_and_cache(
            REPO_ROOT / ADDITIONAL_CI_FILES_FOLDER, to_file, stats_url, lambda raw: raw
        )
    except Exception:
        print(failure_explanation)
        return {}
194