xref: /aosp_15_r20/external/pytorch/scripts/compile_tests/passrate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2
3from common import (
4    get_excluded_testcases,
5    get_passed_testcases,
6    get_testcases,
7    key,
8    open_test_results,
9)
10from download_reports import download_reports
11
12
13"""
14Usage: passrate.py commit_sha
15
16Parses test reports to measure the passrate. The passrate is defined as:
17
A) Take the tests that pass under eager mode, excluding
CUDA, OpInfo, and ModuleInfo tests, and count them
B) Of those tests, count the ones that also pass under Dynamo
C) The passrate is B/A.
22
23You'll need to provide the commit_sha for a commit on the main branch,
24from which we will pull CI test results.
25
This script requires the GitHub CLI (`gh`). You'll need to install it and
then authenticate with it via `gh auth login` before using this script.
28https://docs.github.com/en/github-cli/github-cli/quickstart
29"""
30
31
def testcases_by_time(xmls):
    """Return the testcases found in ``xmls``, slowest first.

    The testcases come from ``get_testcases(xmls)``; the returned list is
    sorted in place, in descending order of each testcase's ``"time"``
    attribute interpreted as a float.
    """

    def _duration(testcase):
        # The "time" attribute is stored as text, so convert before comparing.
        return float(testcase.attrib["time"])

    slowest_first = get_testcases(xmls)
    slowest_first.sort(key=_duration, reverse=True)
    return slowest_first
36
37
def should_exclude(key):
    """Return True if the test identified by ``key`` is left out of the pass rate.

    ``key`` is a ``"::"``-separated string whose first component is the test
    file. A test is excluded when that component is ``"UNKNOWN"`` (C++ tests)
    or lives under ``inductor/``, ``export/`` or ``dynamo/``.
    """
    test_file, _, _ = key.partition("::")
    # Policy: "pass rate" does not include inductor, export, or dynamo tests.
    policy_prefixes = ("inductor/", "export/", "dynamo/")
    # "UNKNOWN" is the test file recorded for C++ tests.
    return test_file == "UNKNOWN" or test_file.startswith(policy_prefixes)
45
46
def compute_pass_rate(eager_dir, dynamo_dir):
    """Print the Dynamo pass rate and return the keys Dynamo did not pass.

    The pass rate is ``len(B) / len(A)`` where ``A`` is the set of keys of
    testcases that passed in the eager results, minus keys rejected by
    ``should_exclude`` and minus keys reported by ``get_excluded_testcases``
    for the Dynamo results, and ``B`` is the part of ``A`` that also passed
    in the Dynamo results.

    When ``A`` is empty the rate is undefined; it is printed as ``nan``
    rather than raising ``ZeroDivisionError``.

    Args:
        eager_dir: Location of the eager-mode reports, forwarded to
            ``open_test_results``.
        dynamo_dir: Location of the Dynamo reports, forwarded to
            ``open_test_results``.

    Returns:
        The set of keys that are in ``A`` but not in ``B``.
    """
    print("parsing xmls")
    eager_xmls = open_test_results(eager_dir)
    dynamo_xmls = open_test_results(dynamo_dir)

    print("computing pass rate")
    eager_passed = get_passed_testcases(eager_xmls)
    dynamo_passed = get_passed_testcases(dynamo_xmls)
    dynamo_pass_keys = {
        key_ for key_ in map(key, dynamo_passed) if not should_exclude(key_)
    }
    # Tests marked as excluded in the Dynamo run do not count against it.
    excluded = {key(t) for t in get_excluded_testcases(dynamo_xmls)}
    eager_pass_keys = {
        key_ for key_ in map(key, eager_passed) if not should_exclude(key_)
    } - excluded

    subset = eager_pass_keys & dynamo_pass_keys
    total_subset = len(subset)
    total_tests = len(eager_pass_keys)
    # Guard the division: empty or fully-excluded reports leave no eager tests.
    pass_rate = total_subset / total_tests if total_tests else float("nan")
    print("pass rate", pass_rate, total_subset, total_tests)

    # Useful for debugging: eager-passing keys that have no testcase at all in
    # the Dynamo results. Intentionally unused; inspect it from a debugger.
    dynamo_testcases = get_testcases(dynamo_xmls)
    tc = {key(t): t for t in dynamo_testcases}
    not_there_keys = {key_ for key_ in eager_pass_keys if key_ not in tc}  # noqa: F841

    fail_keys = eager_pass_keys - subset
    return fail_keys
80
81
if __name__ == "__main__":
    # Command-line entry point: take a commit sha, download the "dynamo311"
    # and "eager311" reports for it, and print the resulting pass rate.
    commit_help = (
        "The commit sha for the latest commit on a PR from which we will "
        "pull CI test results, e.g. 7e5f597aeeba30c390c05f7d316829b3798064a5"
    )
    cli = argparse.ArgumentParser(
        prog="passrate",
        description="Computes the Dynamo unittest pass rate",
    )
    cli.add_argument("commit", help=commit_help)
    parsed = cli.parse_args()

    # download_reports returns one entry per requested report name, in order.
    report_names = ("dynamo311", "eager311")
    dynamo_report, eager_report = download_reports(parsed.commit, report_names)
    compute_pass_rate(eager_report, dynamo_report)
96