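# tools/stats/check_disabled_tests.py: check the results of rerunning disabled tests
# in a workflow run and upload a summary (which tests can be re-enabled and which are
# still flaky) to S3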
from __future__ import annotations

import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Generator

from tools.stats.upload_stats_lib import (
    download_s3_artifacts,
    is_rerun_disabled_tests,
    unzip,
    upload_workflow_stats_to_s3,
)
from tools.stats.upload_test_stats import process_xml_element


TESTCASE_TAG = "testcase"
SEPARATOR = ";"


def process_report(
    report: Path,
) -> dict[str, dict[str, int]]:
    """
    Return the rerun stats of every disabled test found in the report. These are used
    later to decide which tests should be re-enabled and which are still flaky (failed
    or skipped)
    """
    root = ET.parse(report)

    # All rerun tests from a report are grouped here:
    #
    # * A successful test should be re-enabled if it's green after rerunning on all
    #   platforms where it is currently disabled
    # * Failures come from pytest because pytest-flakefinder is used to run the same
    #   test multiple times, so some reruns could fail
    # * Skipped tests come from unittest
    #
    # We want to keep track of how many times the test fails (num_red) or passes (num_green)
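    # Keyed by "name;classname;filename", e.g. (made-up test):
    # {"test_foo;TestBar;test_baz.py": {"num_green": 50, "num_red": 0}}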
    all_tests: dict[str, dict[str, int]] = {}

    for test_case in root.iter(TESTCASE_TAG):
        parsed_test_case = process_xml_element(test_case)

        # Under --rerun-disabled-tests mode, a test is skipped when:
        # * it's skipped explicitly inside PyTorch code
        # * it's skipped because it's a normal enabled test
        # * or it's flaky (num_red > 0 and num_green > 0)
        # * or it's failing (num_red > 0 and num_green == 0)
        #
        # We only care about the latter two cases here
        skipped = parsed_test_case.get("skipped", None)

        # NB: Regular ONNX tests could return a list of subskips here, where each item
        # in the list is a skipped message. In the context of rerunning disabled tests,
        # we can safely ignore this case because a list of subskips only appears when
        # tests are run normally
        if skipped and (
            type(skipped) is list or "num_red" not in skipped.get("message", "")
        ):
            continue

        name = parsed_test_case.get("name", "")
        classname = parsed_test_case.get("classname", "")
        filename = parsed_test_case.get("file", "")

        if not name or not classname or not filename:
            continue

        # Check if the test is a failure
        failure = parsed_test_case.get("failure", None)

        disabled_test_id = SEPARATOR.join([name, classname, filename])
        if disabled_test_id not in all_tests:
            all_tests[disabled_test_id] = {
                "num_green": 0,
                "num_red": 0,
            }

        # Under --rerun-disabled-tests mode, if a test is not skipped or failed, it's
        # counted as a success. Otherwise, it's still flaky or failing
        if skipped:
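            # The skip message in this mode is expected to carry the rerun stats as a
            # JSON string, e.g. (illustrative values) {"num_red": 1, "num_green": 49}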
            try:
                stats = json.loads(skipped.get("message", ""))
            except json.JSONDecodeError:
                stats = {}

            all_tests[disabled_test_id]["num_green"] += stats.get("num_green", 0)
            all_tests[disabled_test_id]["num_red"] += stats.get("num_red", 0)
        elif failure:
            # It's a failure, so increase the failure count
            all_tests[disabled_test_id]["num_red"] += 1
        else:
            all_tests[disabled_test_id]["num_green"] += 1

    return all_tests


def get_test_reports(
    repo: str, workflow_run_id: int, workflow_run_attempt: int
) -> Generator[Path, None, None]:
    """
    Gather all the test reports from S3 and GHA. It is currently not possible to tell
    which test reports come from the rerun_disabled_tests workflow because the name
    doesn't include the test config, so all reports need to be downloaded and examined
    """
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        os.chdir(temp_dir)

        artifact_paths = download_s3_artifacts(
            "test-reports", workflow_run_id, workflow_run_attempt
        )
        for path in artifact_paths:
            unzip(path)

        yield from Path(".").glob("**/*.xml")


def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]:
    """
    Follow the flaky bot naming convention here; if that changes, this will also need
    to be updated
    """
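    # Illustrative example (made-up test): a test_id of "test_foo;TestBar;test_baz.py"
    # maps to the disabled test name "test_foo (__main__.TestBar)"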
    name, classname, filename = test_id.split(SEPARATOR)
    return f"{name} (__main__.{classname})", name, classname, filename


def prepare_record(
    workflow_id: int,
    workflow_run_attempt: int,
    name: str,
    classname: str,
    filename: str,
    flaky: bool,
    num_red: int = 0,
    num_green: int = 0,
) -> tuple[Any, dict[str, Any]]:
    """
    Prepare the record to save onto S3
    """
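    # The key uniquely identifies the record for a given test within a workflow run
    # attempt; save_results keeps records in a dict under this key before uploading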
    key = (
        workflow_id,
        workflow_run_attempt,
        name,
        classname,
        filename,
    )

    record = {
        "workflow_id": workflow_id,
        "workflow_run_attempt": workflow_run_attempt,
        "name": name,
        "classname": classname,
        "filename": filename,
        "flaky": flaky,
        "num_green": num_green,
        "num_red": num_red,
    }

    return key, record


def save_results(
    workflow_id: int,
    workflow_run_attempt: int,
    all_tests: dict[str, dict[str, int]],
) -> None:
    """
    Save the results to S3 so they can be ingested into Rockset
    """
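    # A disabled test is a candidate for re-enabling only if it passed at least once
    # during the reruns and never failed (num_green > 0 and num_red == 0); every other
    # test is treated as still flaky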
    should_be_enabled_tests = {
        name: stats
        for name, stats in all_tests.items()
        if "num_green" in stats
        and stats["num_green"]
        and "num_red" in stats
        and stats["num_red"] == 0
    }
    still_flaky_tests = {
        name: stats
        for name, stats in all_tests.items()
        if name not in should_be_enabled_tests
    }

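    # In the uploaded records, flaky=True means the test should stay disabled, while
    # flaky=False marks it as a candidate for re-enabling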
    records = {}
    for test_id, stats in all_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)
        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)

        key, record = prepare_record(
            workflow_id=workflow_id,
            workflow_run_attempt=workflow_run_attempt,
            name=name,
            classname=classname,
            filename=filename,
            flaky=test_id in still_flaky_tests,
            num_green=num_green,
            num_red=num_red,
        )
        records[key] = record

    # Log the results
    print(f"The following {len(should_be_enabled_tests)} tests should be re-enabled:")
    for test_id, stats in should_be_enabled_tests.items():
        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
        print(f"  {disabled_test_name} from {filename}")

    print(f"The following {len(still_flaky_tests)} tests are still flaky:")
    for test_id, stats in still_flaky_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)

        disabled_test_name, name, classname, filename = get_disabled_test_name(test_id)
        print(
            f"  {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}"
        )

    upload_workflow_stats_to_s3(
        workflow_id,
        workflow_run_attempt,
        "rerun_disabled_tests",
        list(records.values()),
    )


def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
    """
    Find the list of all disabled tests that should be re-enabled
    """
    # Aggregated across all jobs
    all_tests: dict[str, dict[str, int]] = {}

    for report in get_test_reports(repo, workflow_run_id, workflow_run_attempt):
        tests = process_report(report)

        # The scheduled workflow has both rerun disabled tests and memory leak check
        # jobs. We are only interested in the former here
        if not is_rerun_disabled_tests(tests):
            continue

        for name, stats in tests.items():
            if name not in all_tests:
                all_tests[name] = stats.copy()
            else:
                all_tests[name]["num_green"] += stats.get("num_green", 0)
                all_tests[name]["num_red"] += stats.get("num_red", 0)

    save_results(
        workflow_run_id,
        workflow_run_attempt,
        all_tests,
    )


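# Example invocation (hypothetical IDs), assuming the script is run as a module from
# the repo root:
#   python -m tools.stats.check_disabled_tests \
#       --workflow-run-id 1234567890 --workflow-run-attempt 1 --repo pytorch/pytorch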
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check disabled test rerun results and upload them to S3"
    )
    parser.add_argument(
        "--workflow-run-id",
        type=int,
        required=True,
        help="id of the workflow to get artifacts from",
    )
    parser.add_argument(
        "--workflow-run-attempt",
        type=int,
        required=True,
        help="which retry of the workflow this is",
    )
    parser.add_argument(
        "--repo",
        type=str,
        required=True,
        help="which GitHub repo this workflow run belongs to",
    )

    args = parser.parse_args()
    main(args.repo, args.workflow_run_id, args.workflow_run_attempt)