xref: /aosp_15_r20/external/pytorch/scripts/compile_tests/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import functools
2import os
3import warnings
4
5
# Pick the fastest available XML parser. `parse` takes a path (or file object)
# and returns a parsed tree; callers only ever use `parse`.
try:
    import lxml.etree
except ImportError:
    # Standard-library fallback: same interface, just slower.
    import xml.etree.ElementTree as ET

    warnings.warn(
        "lxml was not found. `pip install lxml` to make this script run much faster"
    )
    parse = ET.parse
else:
    # huge_tree=True disables libxml2's security restrictions so very deep
    # trees and very long text content (large test reports) can be parsed.
    p = lxml.etree.XMLParser(huge_tree=True)
    parse = functools.partial(lxml.etree.parse, parser=p)
18
19
def open_test_results(directory):
    """Parse every ``*.xml`` file found (recursively) under ``directory``.

    Sub-directories and files are visited in sorted order so the returned list
    is deterministic; ``os.walk`` otherwise yields entries in whatever order the
    filesystem returns them.

    Args:
        directory: root directory to search. A missing directory yields ``[]``
            (``os.walk`` ignores the error by default).

    Returns:
        A list with one parsed tree per ``.xml`` file, as produced by the
        module-level ``parse``.
    """
    xmls = []
    for root, dirs, files in os.walk(directory):
        # Sorting dirs in place makes os.walk descend in a stable order.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(".xml"):
                xmls.append(parse(os.path.join(root, file)))
    return xmls
28
29
def get_testcases(xmls):
    """Collect every ``<testcase>`` element from a sequence of parsed XML trees.

    Trees are processed in the given order. Within each tree, elements come in
    document order, and the root itself is included if it is a ``<testcase>``.
    """
    return [
        case
        for tree in xmls
        for case in tree.getroot().iter("testcase")
    ]
36
37
def find(testcase, condition):
    """Evaluate ``condition`` on every element nested under ``testcase``.

    ``Element.iter()`` yields the element itself first, followed by ALL of its
    descendants (not only direct children) in depth-first document order. The
    element itself is dropped, and the remaining list is passed to
    ``condition``.

    Returns:
        Whatever ``condition`` returns.
    """
    head, *descendants = testcase.iter()
    assert head is testcase
    return condition(descendants)
43
44
def skipped_test(testcase):
    """Return True if any element nested under ``testcase`` is ``<skipped>``."""
    return find(
        testcase, lambda descendants: any(node.tag == "skipped" for node in descendants)
    )
50
51
def passed_test(testcase):
    """Return True if ``testcase`` was neither skipped nor recorded as failing.

    A testcase with no nested elements counts as passed. Otherwise it is
    rejected when any nested element is:

    * ``<skipped>``,
    * ``<error>``,
    * ``<failed>`` (kept for backward compatibility; JUnit XML does not emit it),
    * ``<failure>``, unless its message says "unexpected success".

    Args:
        testcase: a ``<testcase>`` element.

    Returns:
        bool
    """

    def condition(children):
        for child in children:
            if child.tag in ("skipped", "failed", "error"):
                return False
            if child.tag == "failure":
                message = child.attrib.get("message", "")
                # unittest reports an "unexpected success" when a test marked as
                # an expected failure actually passes. The test body succeeded,
                # so it still counts as passed; any other <failure> does not.
                if "unexpected success" not in message.lower():
                    return False
        return True

    return find(testcase, condition)
60
61
def key(testcase):
    """Build the ``"file::classname::name"`` identifier for a testcase element.

    A missing ``file`` attribute is rendered as ``"UNKNOWN"``. A missing
    ``classname`` or ``name`` raises KeyError (``classname`` is checked first).
    """
    attrs = testcase.attrib
    parts = (attrs.get("file", "UNKNOWN"), attrs["classname"], attrs["name"])
    return "::".join(parts)
67
68
def get_passed_testcases(xmls):
    """Return, in order, the testcases from ``xmls`` accepted by ``passed_test``."""
    return list(filter(passed_test, get_testcases(xmls)))
73
74
def get_excluded_testcases(xmls):
    """Return, in order, the testcases from ``xmls`` flagged by ``excluded_testcase``."""
    return list(filter(excluded_testcase, get_testcases(xmls)))
79
80
def excluded_testcase(testcase):
    """Return True if ``testcase`` was skipped by the "we don't run" policy.

    A testcase is flagged when some nested ``<skipped>`` element's ``message``
    contains ``"Policy: we don't run"``. The ``message`` attribute is optional
    in JUnit XML, so a ``<skipped>`` without one is treated as not matching
    rather than raising KeyError.
    """

    def condition(children):
        return any(
            child.tag == "skipped"
            and "Policy: we don't run" in child.attrib.get("message", "")
            for child in children
        )

    return find(testcase, condition)
90
91
def is_unexpected_success(testcase):
    """Return True if ``testcase`` has a ``<failure>`` reporting an unexpected success.

    The match is a case-insensitive substring search for
    ``"unexpected success"`` in the failure's ``message``. A ``<failure>``
    without a ``message`` attribute (it is optional in JUnit XML) does not
    match, rather than raising KeyError.
    """

    def condition(children):
        return any(
            child.tag == "failure"
            and "unexpected success" in child.attrib.get("message", "").lower()
            for child in children
        )

    return find(testcase, condition)
105
106
# Substring that is_passing_skipped_test looks for in a <skipped> element's
# message. NOTE(review): presumably emitted by the test harness when a test
# listed as skipped in dynamo_test_failures.py actually passes; confirm
# against the emitter.
MSG = "This test passed, maybe we can remove the skip from dynamo_test_failures.py"


def is_passing_skipped_test(testcase):
    """Return True if ``testcase`` has a ``<skipped>`` element whose message contains ``MSG``.

    A ``<skipped>`` without a ``message`` attribute (optional in JUnit XML)
    does not match, rather than raising KeyError.
    """

    def condition(children):
        return any(
            child.tag == "skipped" and MSG in child.attrib.get("message", "")
            for child in children
        )

    return find(testcase, condition)
121
122
# NB: a <failure> reporting an "unexpected success" is deliberately NOT
# counted as a failure here; use is_unexpected_success for those.
def is_failure(testcase):
    """Return True if ``testcase`` has a ``<failure>`` that is not an unexpected success.

    A ``<failure>`` without a ``message`` attribute (optional in JUnit XML) is
    treated as an ordinary failure rather than raising KeyError.
    NOTE(review): ``<error>`` elements are not considered here; confirm whether
    callers want them treated as failures too.
    """

    def condition(children):
        return any(
            child.tag == "failure"
            and "unexpected success" not in child.attrib.get("message", "").lower()
            for child in children
        )

    return find(testcase, condition)
137