xref: /aosp_15_r20/external/pytorch/scripts/compile_tests/update_failures.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2import argparse
3import os
4import subprocess
5from pathlib import Path
6
7from common import (
8    get_testcases,
9    is_failure,
10    is_passing_skipped_test,
11    is_unexpected_success,
12    key,
13    open_test_results,
14)
15from download_reports import download_reports
16
17
18"""
19Usage: update_failures.py /path/to/dynamo_test_failures.py /path/to/test commit_sha
20
21Best-effort updates the xfail and skip files under test directory
22by parsing test reports.
23
24You'll need to provide the commit_sha for the latest commit on a PR
25from which we will pull CI test results.
26
27Instructions:
- On your PR, add the "keep-going" label to ensure that CI runs all the tests
  (as opposed to stopping on the first failure). You may need to restart your
  test jobs by force-pushing to your branch for CI to pick up the
  "keep-going" label.
32- Wait for all the tests to finish running.
33- Find the full SHA of your commit and run this command.
34
35This script requires the `gh` cli. You'll need to install it and then
36authenticate with it via `gh auth login` before using this script.
37https://docs.github.com/en/github-cli/github-cli/quickstart
38"""
39
40
def patch_file(
    filename, test_dir, unexpected_successes, new_xfails, new_skips, unexpected_skips
):
    """Best-effort sync of the on-disk xfail/skip marker files with test results.

    Marker files are empty files named ``<classname>.<name>`` that live in
    ``<test_dir>/dynamo_expected_failures`` and ``<test_dir>/dynamo_skips``.
    Files are created/removed on disk and staged through ``git add`` /
    ``git rm``; a failing git command is reported as a warning rather than
    aborting the run.

    Args:
        filename: Path to the Python file that contains a literal
            ``extra_dynamo_skips = {`` ... ``}`` block of hand-written skips.
        test_dir: Directory containing the ``dynamo_expected_failures`` and
            ``dynamo_skips`` sub-directories.
        unexpected_successes: Mapping whose values are testcase objects (each
            exposing ``attrib["classname"]`` and ``attrib["name"]``); the
            matching expected-failure marker files are removed.
        new_xfails: Mapping of testcases for which an expected-failure marker
            file is created.
        new_skips: Mapping of testcases for which a skip marker file is
            created.
        unexpected_skips: Mapping of testcases whose skip marker files are
            removed.

    Returns:
        None. Warnings are printed to stdout for entries that could not be
        handled automatically.
    """
    failures_directory = os.path.join(test_dir, "dynamo_expected_failures")
    skips_directory = os.path.join(test_dir, "dynamo_skips")

    dynamo_expected_failures = set(os.listdir(failures_directory))
    dynamo_skips = set(os.listdir(skips_directory))

    # Hand-written skips are recovered by a plain text scan of the
    # `extra_dynamo_skips = {` ... `}` block; the file is never imported.
    extra_dynamo_skips = set()
    with open(filename) as f:
        in_block = False
        for line in f:
            text = line.strip()
            if not in_block:
                in_block = text == "extra_dynamo_skips = {"
                continue
            if text == "}":
                break
            # Drop any trailing `# comment`, then the quotes/comma around the
            # entry. Blank and comment-only lines yield "" and are ignored.
            entry = text.split("#", 1)[0].strip(" \t,\"'")
            if entry:
                extra_dynamo_skips.add(entry)

    def qualified_name(testcase):
        # Marker files are named "<classname>.<name>".
        return f"{testcase.attrib['classname']}.{testcase.attrib['name']}"

    formatted_unexpected_successes = {
        qualified_name(test) for test in unexpected_successes.values()
    }
    formatted_unexpected_skips = {
        qualified_name(test) for test in unexpected_skips.values()
    }
    formatted_new_xfails = [qualified_name(test) for test in new_xfails.values()]
    formatted_new_skips = [qualified_name(test) for test in new_skips.values()]

    def run_git(*args):
        # Best-effort: surface a failing git command instead of silently
        # ignoring its exit status, but keep processing the remaining files.
        result = subprocess.run(["git", *args])
        if result.returncode != 0:
            print(
                f"WARNING: `git {' '.join(args)}` exited with "
                f"code {result.returncode}"
            )

    def remove_file(path, name):
        run_git("rm", os.path.join(path, name))

    def add_file(path, name):
        file = os.path.join(path, name)
        # The marker is an empty file; opening in "w" mode creates it.
        with open(file, "w"):
            pass
        run_git("add", file)

    # dynamo_expected_failures
    # Only unexpected successes that have a marker file can be removed here.
    covered_unexpected_successes = (
        dynamo_expected_failures & formatted_unexpected_successes
    )
    for test in sorted(covered_unexpected_successes):
        remove_file(failures_directory, test)
    for test in formatted_new_xfails:
        add_file(failures_directory, test)

    leftover_unexpected_successes = (
        formatted_unexpected_successes - covered_unexpected_successes
    )
    if leftover_unexpected_successes:
        print(
            "WARNING: we were unable to remove these "
            f"{len(leftover_unexpected_successes)} expectedFailures:"
        )
        for stuff in sorted(leftover_unexpected_successes):
            print(stuff)

    # dynamo_skips
    for test in sorted(dynamo_skips & formatted_unexpected_skips):
        remove_file(skips_directory, test)
    # Hand-written skips cannot be edited automatically; tell the user.
    for test in sorted(extra_dynamo_skips & formatted_unexpected_skips):
        print(
            f"WARNING: {test} in dynamo_test_failures.py needs to be removed manually"
        )
    for test in formatted_new_skips:
        add_file(skips_directory, test)
122
123
def get_intersection_and_outside(a_dict, b_dict):
    """Split the entries of two dicts by whether their key appears in both.

    Returns a pair ``(in_both, in_one)``: ``in_both`` holds the keys present
    in both dicts, ``in_one`` holds the keys present in exactly one of them.
    Values are taken from ``a_dict`` when the key is there, otherwise from
    ``b_dict``.
    """
    keys_a = set(a_dict)
    keys_b = set(b_dict)
    shared = keys_a & keys_b
    exclusive = (keys_a | keys_b) - shared

    def pick(keys):
        # a_dict wins whenever it has the key.
        return {k: (a_dict[k] if k in a_dict else b_dict[k]) for k in keys}

    in_both = pick(shared)
    in_one = pick(exclusive)
    return in_both, in_one
140
141
def update(filename, test_dir, py38_dir, py311_dir, also_remove_skips):
    """Classify the test results of two report directories and patch the
    xfail/skip files accordingly.

    A test that unexpectedly succeeds, or fails, in only one of the two
    report directories becomes a new skip; a test that fails in both becomes
    a new xfail. When ``also_remove_skips`` is true, tests reported as
    passing-while-skipped in both directories are handed to ``patch_file``
    as unexpected skips.

    Args:
        filename: Forwarded unchanged to ``patch_file``.
        test_dir: Forwarded unchanged to ``patch_file``.
        py38_dir: First report directory (the script entry point passes the
            ``dynamo39`` reports here).
        py311_dir: Second report directory (the ``dynamo311`` reports).
        also_remove_skips: Whether to also compute skips to remove.

    Returns:
        Whatever ``patch_file`` returns.
    """

    def collect(directory):
        cases = get_testcases(open_test_results(directory))

        def select(predicate):
            return {key(case): case for case in cases if predicate(case)}

        # Order matters only for readability: successes, failures, skips.
        return (
            select(is_unexpected_success),
            select(is_failure),
            select(is_passing_skipped_test),
        )

    first_successes, first_failures, first_passing_skips = collect(py38_dir)
    second_successes, second_failures, second_passing_skips = collect(py311_dir)

    unexpected_successes = {**first_successes, **second_successes}

    # Succeeding in only one of the two runs -> needs a skip.
    skips = get_intersection_and_outside(first_successes, second_successes)[1]
    # Failing in both runs -> xfail; failing in only one -> skip.
    xfails, more_skips = get_intersection_and_outside(first_failures, second_failures)

    unexpected_skips = {}
    if also_remove_skips:
        unexpected_skips = get_intersection_and_outside(
            first_passing_skips, second_passing_skips
        )[0]

    all_skips = {**skips, **more_skips}

    summary = (
        f"Discovered {len(unexpected_successes)} new unexpected successes, "
        f"{len(xfails)} new xfails, "
        f"{len(all_skips)} new skips, "
        f"{len(unexpected_skips)} new unexpected skips"
    )
    print(summary)

    return patch_file(
        filename, test_dir, unexpected_successes, xfails, all_skips, unexpected_skips
    )
185
186
if __name__ == "__main__":
    # Default paths are resolved relative to the directory three levels above
    # this script.
    repo_root = Path(__file__).absolute().parent.parent.parent

    parser = argparse.ArgumentParser(
        prog="update_dynamo_test_failures",
        description="Read from logs and update the dynamo_test_failures file",
    )
    # dynamo_test_failures path
    parser.add_argument(
        "filename",
        nargs="?",
        default=str(repo_root / "torch/testing/_internal/dynamo_test_failures.py"),
        help="Optional path to dynamo_test_failures.py",
    )
    # test path
    parser.add_argument(
        "test_dir",
        nargs="?",
        default=str(repo_root / "test"),
        help="Optional path to test folder",
    )
    parser.add_argument(
        "commit",
        help=(
            "The commit sha for the latest commit on a PR from which we will "
            "pull CI test results, e.g. 7e5f597aeeba30c390c05f7d316829b3798064a5"
        ),
    )
    parser.add_argument(
        "--also-remove-skips",
        help="Also attempt to remove skips. WARNING: does not guard against test flakiness",
        action="store_true",
    )
    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is stripped under
    # `python -O`; parser.error prints the usage text and exits.
    for required_path in (args.filename, args.test_dir):
        if not Path(required_path).exists():
            parser.error(f"path does not exist: {required_path}")

    dynamo39, dynamo311 = download_reports(args.commit, ("dynamo39", "dynamo311"))
    update(args.filename, args.test_dir, dynamo39, dynamo311, args.also_remove_skips)
226