#!/usr/bin/env python3
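"""Render JUnit/xUnit XML test reports in the terminal, highlighting failures.

Example usage (the report path here is only illustrative):

    python tools/render_junit.py test-reports/
"""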

from __future__ import annotations

import argparse
import os
from typing import Any


try:
    from junitparser import (  # type: ignore[import]
        Error,
        Failure,
        JUnitXml,
        TestCase,
        TestSuite,
    )
except ImportError as e:
    raise ImportError(
        "junitparser not found, please install with 'pip install junitparser'"
    ) from e

try:
    import rich
except ImportError as e:
    # rich.print is used unconditionally below, so rich is required, not optional.
    raise ImportError(
        "rich not found, please install with 'pip install rich'"
    ) from e


def parse_junit_reports(path_to_reports: str) -> list[TestCase]:  # type: ignore[no-any-unimported]
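    """Collect TestCase objects from a JUnit XML report file or a directory of reports.

    ``path_to_reports`` may be a single XML file or a directory, which is walked
    recursively; files that cannot be parsed are skipped with a warning.
    """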
    def parse_file(path: str) -> list[TestCase]:  # type: ignore[no-any-unimported]
        try:
            return convert_junit_to_testcases(JUnitXml.fromfile(path))
        except Exception as err:
            rich.print(
                f":warning: [yellow]Warning[/yellow]: Failed to read {path}: {err}"
            )
            return []

    if not os.path.exists(path_to_reports):
        raise FileNotFoundError(f"Path '{path_to_reports}' not found")
    # Return early if the path provided is just a file
    if os.path.isfile(path_to_reports):
        return parse_file(path_to_reports)
    ret_xml = []
    if os.path.isdir(path_to_reports):
        # Walk the directory tree and collect cases from every XML report found.
        for root, _, files in os.walk(path_to_reports):
            for fname in [f for f in files if f.endswith(".xml")]:
                ret_xml += parse_file(os.path.join(root, fname))
    return ret_xml


def convert_junit_to_testcases(xml: JUnitXml | TestSuite) -> list[TestCase]:  # type: ignore[no-any-unimported]
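    """Flatten a JUnitXml document (or a single TestSuite) into a list of TestCase objects.

    Suites may contain nested suites, so recurse whenever a TestSuite is encountered.
    """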
    testcases = []
    for item in xml:
        if isinstance(item, TestSuite):
            testcases.extend(convert_junit_to_testcases(item))
        else:
            testcases.append(item)
    return testcases


def render_tests(testcases: list[TestCase]) -> None:  # type: ignore[no-any-unimported]
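    """Print every errored or failed test case with its output, followed by summary counts."""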
    num_passed = 0
    num_skipped = 0
    num_failed = 0
    for testcase in testcases:
        if not testcase.result:
            num_passed += 1
            continue
        for result in testcase.result:
            if isinstance(result, Error):
                icon = ":rotating_light: [white on red]ERROR[/white on red]:"
                num_failed += 1
            elif isinstance(result, Failure):
                icon = ":x: [white on red]Failure[/white on red]:"
                num_failed += 1
            else:
                # Anything that is neither an Error nor a Failure (e.g. Skipped).
                num_skipped += 1
                continue
            rich.print(
                f"{icon} [bold red]{testcase.classname}.{testcase.name}[/bold red]"
            )
            print(f"{result.text}")
    rich.print(f":white_check_mark: {num_passed} [green]Passed[/green]")
    rich.print(f":dash: {num_skipped} [grey]Skipped[/grey]")
    rich.print(f":rotating_light: {num_failed} [grey]Failed[/grey]")


def parse_args() -> Any:
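    """Parse command-line arguments: a single positional path to the report(s)."""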
    parser = argparse.ArgumentParser(
        description="Render xunit output for failed tests",
    )
    parser.add_argument(
        "report_path",
        help="Path to xunit reports (single XML file or directory) to render",
    )
    return parser.parse_args()


def main() -> None:
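    """Parse arguments, read the JUnit reports, and render the results."""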
    options = parse_args()
    testcases = parse_junit_reports(options.report_path)
    render_tests(testcases)


if __name__ == "__main__":
    main()