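"""Upload per-test statistics for a CI workflow run.

Downloads the test-report XML artifacts for a workflow run, parses them into
JSON-serializable test case records, and uploads a per-file summary (plus the
failed cases, and the full per-test data for pytorch/pytorch main) via
upload_workflow_stats_to_s3 and upload_additional_info.

Example invocation (values illustrative; assumes the repo root is on PYTHONPATH):
    python3 -m tools.stats.upload_test_stats --workflow-run-id 123456789 --workflow-run-attempt 1 --head-branch main --head-repository pytorch/pytorch
"""
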
from __future__ import annotations

import argparse
import os
import sys
import xml.etree.ElementTree as ET
from multiprocessing import cpu_count, Pool
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

from tools.stats.test_dashboard import upload_additional_info
from tools.stats.upload_stats_lib import (
    download_s3_artifacts,
    get_job_id,
    unzip,
    upload_workflow_stats_to_s3,
)


def parse_xml_report(
    tag: str,
    report: Path,
    workflow_id: int,
    workflow_run_attempt: int,
) -> list[dict[str, Any]]:
    """Convert a test report xml file into a JSON-serializable list of test cases."""
    print(f"Parsing {tag}s for test report: {report}")

    job_id = get_job_id(report)
    print(f"Found job id: {job_id}")

    test_cases: list[dict[str, Any]] = []

    root = ET.parse(report)
    for test_case in root.iter(tag):
        case = process_xml_element(test_case)
        case["workflow_id"] = workflow_id
        case["workflow_run_attempt"] = workflow_run_attempt
        case["job_id"] = job_id

        # [invoking file]
        # The name of the file that the test is located in is not necessarily
        # the same as the name of the file that invoked the test.
        # For example, `test_jit.py` calls into multiple other test files (e.g.
        # jit/test_dce.py). For sharding/test selection purposes, we want to
        # record the file that invoked the test.
        #
        # To do this, we leverage an implementation detail of how we write out
        # tests (https://bit.ly/3ajEV1M), which is that reports are created
        # under a folder with the same name as the invoking file.
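        # For example (illustrative path), a report written under
        #   test/test-reports/python-unittest/test_jit/TEST-jit.test_dce.TestDCE.xml
        # gets invoking_file="test_jit" even though the case lives in jit/test_dce.py.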
        case["invoking_file"] = report.parent.name
        test_cases.append(case)

    return test_cases


def process_xml_element(element: ET.Element) -> dict[str, Any]:
    """Convert an XML element (and its children) into a JSON-serializable dict."""
    ret: dict[str, Any] = {}

    # Convert attributes directly into dict elements.
    # e.g.
    #     <testcase name="test_foo" classname="test_bar"></testcase>
    # becomes:
    #     {"name": "test_foo", "classname": "test_bar"}
    ret.update(element.attrib)

    # The XML format encodes all values as strings. Convert to ints/floats
    # where possible (prefer int, fall back to float) to make aggregation
    # possible in Rockset.
    for k, v in ret.items():
        try:
            ret[k] = int(v)
        except ValueError:
            try:
                ret[k] = float(v)
            except ValueError:
                pass
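    # e.g. {"tests": "3", "time": "0.5"} becomes {"tests": 3, "time": 0.5}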

    # Convert inner and outer text into special dict elements.
    # e.g.
    #     <testcase>my_inner_text</testcase> my_tail
    # becomes:
    #     {"text": "my_inner_text", "tail": " my_tail"}
    if element.text and element.text.strip():
        ret["text"] = element.text
    if element.tail and element.tail.strip():
        ret["tail"] = element.tail

    # Convert child elements recursively, placing them at a key:
    # e.g.
    #     <testcase>
    #       <foo>hello</foo>
    #       <foo>world</foo>
    #       <bar>another</bar>
    #     </testcase>
    # becomes
    #    {
    #       "foo": [{"text": "hello"}, {"text": "world"}],
    #       "bar": {"text": "another"}
    #    }
    for child in element:
        if child.tag not in ret:
            ret[child.tag] = process_xml_element(child)
        else:
            # If there are multiple tags with the same name, they should be
            # coalesced into a list.
            if not isinstance(ret[child.tag], list):
                ret[child.tag] = [ret[child.tag]]
            ret[child.tag].append(process_xml_element(child))
    return ret


def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str, Any]]:
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        os.chdir(temp_dir)
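        # Note: chdir affects the whole process; the relative glob below
        # resolves inside this temporary directory.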

        # Download and extract all the reports (both GHA and S3)
        s3_paths = download_s3_artifacts(
            "test-report", workflow_run_id, workflow_run_attempt
        )
        for path in s3_paths:
            unzip(path)

        # Parse the reports and transform them to JSON
        test_cases = []
        mp = Pool(cpu_count())
        for xml_report in Path(".").glob("**/*.xml"):
            test_cases.append(
                mp.apply_async(
                    parse_xml_report,
                    args=(
                        "testcase",
                        xml_report,
                        workflow_run_id,
                        workflow_run_attempt,
                    ),
                )
            )
        mp.close()
        mp.join()
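        # AsyncResult.get() re-raises any exception thrown in a worker, so a
        # parse failure in any report surfaces here.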
        test_cases = [tc.get() for tc in test_cases]
        flattened = [item for sublist in test_cases for item in sublist]
        return flattened


def get_tests_for_circleci(
    workflow_run_id: int, workflow_run_attempt: int
) -> list[dict[str, Any]]:
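    # Unlike get_tests(), this assumes the XML reports are already on disk
    # under the current working directory (there is no S3 download step).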
    # Parse the reports and transform them to JSON
    test_cases = []
    for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
        test_cases.extend(
            parse_xml_report(
                "testcase", xml_report, workflow_run_id, workflow_run_attempt
            )
        )

    return test_cases


def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Group test cases by file, classname, invoking file, and job/workflow ids.

    We perform the aggregation manually instead of using the `test-suite` XML
    tag because xmlrunner does not produce reliable output for it.
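
    Each summary row carries the grouping fields plus counters, roughly
    (values illustrative):
        {"file": "test_ops.py", "classname": "TestCommon", ...,
         "tests": 120, "failures": 1, "errors": 0, "skipped": 4,
         "successes": 115, "time": 42.0}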
    """

    def get_key(test_case: dict[str, Any]) -> Any:
        return (
            test_case.get("file"),
            test_case.get("classname"),
            test_case["job_id"],
            test_case["workflow_id"],
            test_case["workflow_run_attempt"],
            # [see: invoking file]
            test_case["invoking_file"],
        )

    def init_value(test_case: dict[str, Any]) -> dict[str, Any]:
        return {
            "file": test_case.get("file"),
            "classname": test_case.get("classname"),
            "job_id": test_case["job_id"],
            "workflow_id": test_case["workflow_id"],
            "workflow_run_attempt": test_case["workflow_run_attempt"],
            # [see: invoking file]
            "invoking_file": test_case["invoking_file"],
            "tests": 0,
            "failures": 0,
            "errors": 0,
            "skipped": 0,
            "successes": 0,
            "time": 0.0,
        }

    ret = {}
    for test_case in test_cases:
        key = get_key(test_case)
        if key not in ret:
            ret[key] = init_value(test_case)

        ret[key]["tests"] += 1

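        # A <failure>, <error>, or <skipped> child of the <testcase> element
        # shows up as a same-named key after process_xml_element().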
        if "failure" in test_case:
            ret[key]["failures"] += 1
        elif "error" in test_case:
            ret[key]["errors"] += 1
        elif "skipped" in test_case:
            ret[key]["skipped"] += 1
        else:
            ret[key]["successes"] += 1

        ret[key]["time"] += test_case["time"]
    return list(ret.values())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Upload test stats to Rockset")
    parser.add_argument(
        "--workflow-run-id",
        type=int,
        required=True,
        help="id of the workflow to get artifacts from",
    )
    parser.add_argument(
        "--workflow-run-attempt",
        type=int,
        required=True,
        help="which retry of the workflow this is",
    )
    parser.add_argument(
        "--head-branch",
        required=True,
        help="Head branch of the workflow",
    )
    parser.add_argument(
        "--head-repository",
        required=True,
        help="Head repository of the workflow",
    )
    parser.add_argument(
        "--circleci",
        action="store_true",
        help="If this is being run through circleci",
    )
    args = parser.parse_args()

    print(f"Workflow id is: {args.workflow_run_id}")

    if args.circleci:
        test_cases = get_tests_for_circleci(
            args.workflow_run_id, args.workflow_run_attempt
        )
    else:
        test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)

    # Flush stdout so that any errors in Rockset upload show up last in the logs.
    sys.stdout.flush()

    # For PRs, only upload a summary of test_runs. This helps lower the
    # volume of writes we do to Rockset.
    test_case_summary = summarize_test_cases(test_cases)

    upload_workflow_stats_to_s3(
        args.workflow_run_id,
        args.workflow_run_attempt,
        "test_run_summary",
        test_case_summary,
    )

    # Separate out the failed test cases.
    # Uploading everything is too data intensive most of the time,
    # but these will be just a tiny fraction.
    failed_test_cases = []
    for test_case in test_cases:
        if "rerun" in test_case or "failure" in test_case or "error" in test_case:
            failed_test_cases.append(test_case)

    upload_workflow_stats_to_s3(
        args.workflow_run_id,
        args.workflow_run_attempt,
        "failed_test_runs",
        failed_test_cases,
    )

    if args.head_branch == "main" and args.head_repository == "pytorch/pytorch":
        # For jobs on main branch, upload everything.
        upload_workflow_stats_to_s3(
            args.workflow_run_id, args.workflow_run_attempt, "test_run", test_cases
        )

    upload_additional_info(args.workflow_run_id, args.workflow_run_attempt, test_cases)
294