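"""Aggregate additional test statistics for a workflow run.

Groups test cases by build name, test config, and invoking file, then uploads
reruns, TD exclusions, and per-invoking-file summaries to S3 under
``additional_info/`` for the test dashboard.
"""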
from __future__ import annotations

import json
import os
import re
import time
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, cast

import requests

from tools.stats.upload_stats_lib import (
    _get_request_headers,
    download_s3_artifacts,
    get_job_id,
    unzip,
    upload_workflow_stats_to_s3,
)

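# Matches job names of the form "<build name> / ... test (<test config>, ...)";
# group 1 captures the build name and group 2 the test config.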
REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)"


@lru_cache(maxsize=1000)
def get_job_name(job_id: int) -> str:
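    """Return the GitHub Actions job name for ``job_id``.

    Results are cached; returns "NoJobName" if the API request fails.
    """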
    try:
        return cast(
            str,
            requests.get(
                f"https://api.github.com/repos/pytorch/pytorch/actions/jobs/{job_id}",
                headers=_get_request_headers(),
            ).json()["name"],
        )
    except Exception as e:
        print(f"Failed to get job name for job id {job_id}: {e}")
        return "NoJobName"


@lru_cache(maxsize=1000)
def get_build_name(job_name: str) -> str:
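    """Return the build (environment) name parsed from ``job_name``, or "NoBuildEnv"."""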
    try:
        return re.match(REGEX_JOB_INFO, job_name).group(1)  # type: ignore[union-attr]
    except AttributeError:
        print(f"Failed to match job name: {job_name}")
        return "NoBuildEnv"


@lru_cache(maxsize=1000)
def get_test_config(job_name: str) -> str:
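    """Return the test config parsed from ``job_name``, or "NoTestConfig"."""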
    try:
        return re.match(REGEX_JOB_INFO, job_name).group(2)  # type: ignore[union-attr]
    except AttributeError:
        print(f"Failed to match job name: {job_name}")
        return "NoTestConfig"


def get_td_exclusions(
    workflow_run_id: int, workflow_run_attempt: int
) -> dict[str, Any]:
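    """Download the run's test-jsons artifacts and collect the excluded test
    files reported in td_exclusions*.json, grouped by build name and test config.
    """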
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        os.chdir(temp_dir)

        # Download and extract all the reports (both GHA and S3)
        s3_paths = download_s3_artifacts(
            "test-jsons", workflow_run_id, workflow_run_attempt
        )
        for path in s3_paths:
            unzip(path)

        grouped_tests: dict[str, Any] = defaultdict(lambda: defaultdict(set))
        for td_exclusions in Path(".").glob("**/td_exclusions*.json"):
            # The job info depends only on the report's location, so resolve it
            # once per file rather than once per excluded test
            job_id = get_job_id(td_exclusions)
            job_name = get_job_name(job_id)
            build_name = get_build_name(job_name)
            test_config = get_test_config(job_name)
            with open(td_exclusions) as f:
                exclusions = json.load(f)
                for exclusion in exclusions["excluded"]:
                    grouped_tests[build_name][test_config].add(exclusion["test_file"])

        # Sets are not JSON-serializable; convert them to sorted lists
        for build_name, build in grouped_tests.items():
            for test_config, test_files in build.items():
                grouped_tests[build_name][test_config] = sorted(test_files)
        return grouped_tests


def group_test_cases(test_cases: list[dict[str, Any]]) -> dict[str, Any]:
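    """Group test cases by build name, test config, invoking file, class name,
    and test name.
    """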
    start = time.time()
    grouped_tests: dict[str, Any] = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        )
    )
    for test_case in test_cases:
        job_name = get_job_name(test_case["job_id"])
        build_name = get_build_name(job_name)
        if "bazel" in build_name:
            continue
        test_config = get_test_config(job_name)
        class_name = test_case.pop("classname", "NoClass")
        name = test_case.pop("name", "NoName")
        invoking_file = test_case.pop("invoking_file", "NoFile")
        invoking_file = invoking_file.replace(".", "/")
        test_case.pop("workflow_id")
        test_case.pop("workflow_run_attempt")
        grouped_tests[build_name][test_config][invoking_file][class_name][name].append(
            test_case
        )

    print(f"Time taken to group tests: {time.time() - start}")
    return grouped_tests


def get_reruns(grouped_tests: dict[str, Any]) -> dict[str, Any]:
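    """Return the test cases that have more than one recorded run, keyed the
    same way as the grouped input.
    """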
    reruns: dict[str, Any] = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        )
    )
    for build_name, build in grouped_tests.items():
        for test_config, test_config_data in build.items():
            for invoking_file, invoking_file_data in test_config_data.items():
                for class_name, class_data in invoking_file_data.items():
                    for test_name, test_data in class_data.items():
                        if len(test_data) > 1:
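                            # These files are expected to report the same test
                            # more than once (assumption based on the file
                            # names), so they are not treated as reruns.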
                            if invoking_file in (
                                "distributed/test_distributed_spawn",
                                "onnx/test_fx_to_onnx_with_onnxruntime",
                                "distributed/algorithms/quantization/test_quantization",
                            ):
                                continue
                            reruns[build_name][test_config][invoking_file][class_name][
                                test_name
                            ] = test_data
    return reruns


def get_invoking_file_summary(grouped_tests: dict[str, Any]) -> dict[str, Any]:
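    """Summarize the test count and total time per build name, test config, and
    invoking file.
    """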
    invoking_file_summary: dict[str, Any] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0}))
    )
    for build_name, build in grouped_tests.items():
        for test_config, test_config_data in build.items():
            for invoking_file, invoking_file_data in test_config_data.items():
                for class_data in invoking_file_data.values():
                    for test_data in class_data.values():
                        invoking_file_summary[build_name][test_config][invoking_file][
                            "count"
                        ] += 1
                        for i in test_data:
                            invoking_file_summary[build_name][test_config][
                                invoking_file
                            ]["time"] += i["time"]

    return invoking_file_summary


def upload_additional_info(
    workflow_run_id: int, workflow_run_attempt: int, test_cases: list[dict[str, Any]]
) -> None:
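    """Upload rerun, TD exclusion, and per-file summary data for the run to S3
    under the ``additional_info/`` prefix.
    """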
    grouped_tests = group_test_cases(test_cases)
    reruns = get_reruns(grouped_tests)
    exclusions = get_td_exclusions(workflow_run_id, workflow_run_attempt)
    invoking_file_summary = get_invoking_file_summary(grouped_tests)

    upload_workflow_stats_to_s3(
        workflow_run_id,
        workflow_run_attempt,
        "additional_info/reruns",
        [reruns],
    )
    upload_workflow_stats_to_s3(
        workflow_run_id,
        workflow_run_attempt,
        "additional_info/td_exclusions",
        [exclusions],
    )
    upload_workflow_stats_to_s3(
        workflow_run_id,
        workflow_run_attempt,
        "additional_info/invoking_file_summary",
        [invoking_file_summary],
    )