"""
Aggregate the results of rerunning disabled tests from the XML test reports of a
workflow run and upload the per-test pass/fail statistics to S3.
"""

from __future__ import annotations

import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Generator

from tools.stats.upload_stats_lib import (
    download_s3_artifacts,
    is_rerun_disabled_tests,
    unzip,
    upload_workflow_stats_to_s3,
)
from tools.stats.upload_test_stats import process_xml_element


TESTCASE_TAG = "testcase"
# Joins name, classname and filename into a single test id (see process_report)
SEPARATOR = ";"


def process_report(
    report: Path,
) -> dict[str, dict[str, int]]:
    """
    Parse one XML test report and count how many times each test case in it passed
    and failed.

    Returns a dict keyed by test id (name, classname and filename joined by
    SEPARATOR) whose values are dicts with two counters: "num_green" (passes) and
    "num_red" (failures). Test cases missing a name, classname or file are ignored.
    """
    root = ET.parse(report)

    # All rerun tests from a report are grouped here:
    #
    # * Success test should be re-enabled if it's green after rerunning in all platforms
    #   where it is currently disabled
    # * Failures from pytest because pytest-flakefinder is used to run the same test
    #   multiple times, some of which could fail
    # * Skipped tests from unittest
    #
    # We want to keep track of how many times the test fails (num_red) or passes (num_green)
    all_tests: dict[str, dict[str, int]] = {}

    for test_case in root.iter(TESTCASE_TAG):
        parsed_test_case = process_xml_element(test_case)

        # Under --rerun-disabled-tests mode, a test is skipped when:
        # * it's skipped explicitly inside PyTorch code
        # * it's skipped because it's a normal enabled test
        # * or it's flaky (num_red > 0 and num_green > 0)
        # * or it's failing (num_red > 0 and num_green == 0)
        #
        # We care only about the latter two here
        skipped = parsed_test_case.get("skipped", None)

        # NB: Regular ONNX tests could return a list of subskips here where each item in the
        # list is a skipped message. In the context of rerunning disabled tests, we could
        # ignore this case as returning a list of subskips only happens when tests are run
        # normally
        if skipped and (
            isinstance(skipped, list) or "num_red" not in skipped.get("message", "")
        ):
            continue

        name = parsed_test_case.get("name", "")
        classname = parsed_test_case.get("classname", "")
        filename = parsed_test_case.get("file", "")

        if not name or not classname or not filename:
            continue

        # Check if the test is a failure
        failure = parsed_test_case.get("failure", None)

        disabled_test_id = SEPARATOR.join([name, classname, filename])
        if disabled_test_id not in all_tests:
            all_tests[disabled_test_id] = {
                "num_green": 0,
                "num_red": 0,
            }

        # Under --rerun-disabled-tests mode, if a test is not skipped or failed, it's
        # counted as a success. Otherwise, it's still flaky or failing
        if skipped:
            try:
                stats = json.loads(skipped.get("message", ""))
            except json.JSONDecodeError:
                stats = {}

            # A message that is valid JSON but not an object (e.g. a bare string)
            # carries no counters, treat it like an unparsable message
            if not isinstance(stats, dict):
                stats = {}

            all_tests[disabled_test_id]["num_green"] += stats.get("num_green", 0)
            all_tests[disabled_test_id]["num_red"] += stats.get("num_red", 0)
        elif failure:
            # As a failure, increase the failure count
            all_tests[disabled_test_id]["num_red"] += 1
        else:
            all_tests[disabled_test_id]["num_green"] += 1

    return all_tests


def get_test_reports(
    repo: str, workflow_run_id: int, workflow_run_attempt: int
) -> Generator[Path, None, None]:
    """
    Download the "test-reports" artifacts of the given workflow run attempt from S3,
    unzip them, and yield the path of every XML file found. It is currently not
    possible to guess which test reports are from rerun_disabled_tests workflow because
    the name doesn't include the test config. So, all reports will need to be
    downloaded and examined.

    The artifacts live in a temporary directory which becomes the current working
    directory while the generator is running, and the yielded paths are relative to
    it. They are therefore only usable while iterating; once the generator finishes
    (or is closed) the previous working directory is restored and the temporary
    directory is removed.

    The repo parameter is currently unused.
    """
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        previous_cwd = os.getcwd()
        # NOTE(review): download_s3_artifacts and unzip presumably write relative to
        # the current directory, which is why we chdir here - confirm against helpers
        os.chdir(temp_dir)

        try:
            artifact_paths = download_s3_artifacts(
                "test-reports", workflow_run_id, workflow_run_attempt
            )
            for path in artifact_paths:
                unzip(path)

            yield from Path(".").glob("**/*.xml")
        finally:
            # Leave the temporary directory before it is deleted so the process is
            # not left with a working directory that no longer exists
            os.chdir(previous_cwd)


def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]:
    """
    Split a test id built by process_report back into its parts and return
    (display name, name, classname, filename), where the display name has the form
    "name (__main__.classname)".

    Follow flaky bot convention here, if that changes, this will also need to be updated.

    Raises ValueError if test_id does not consist of exactly three SEPARATOR-delimited
    parts.
    """
    name, classname, filename = test_id.split(SEPARATOR)
    return f"{name} (__main__.{classname})", name, classname, filename


def prepare_record(
    workflow_id: int,
    workflow_run_attempt: int,
    name: str,
    classname: str,
    filename: str,
    flaky: bool,
    num_red: int = 0,
    num_green: int = 0,
) -> tuple[Any, dict[str, Any]]:
    """
    Prepare the record to save onto S3.

    Returns a (key, record) pair: key is the tuple (workflow_id, workflow_run_attempt,
    name, classname, filename) identifying the test within the run, and record is the
    dict to upload, which additionally holds the flaky flag and both counters.
    """
    key = (
        workflow_id,
        workflow_run_attempt,
        name,
        classname,
        filename,
    )

    record = {
        "workflow_id": workflow_id,
        "workflow_run_attempt": workflow_run_attempt,
        "name": name,
        "classname": classname,
        "filename": filename,
        "flaky": flaky,
        "num_green": num_green,
        "num_red": num_red,
    }

    return key, record


def save_results(
    workflow_id: int,
    workflow_run_attempt: int,
    all_tests: dict[str, dict[str, int]],
) -> None:
    """
    Print a summary of the aggregated results and upload one record per test to S3
    under the "rerun_disabled_tests" collection.

    A test should be re-enabled when it passed at least once and never failed; every
    other test is reported as still flaky.
    """
    # stats.get("num_red") is None when the key is missing, so such a test is never
    # considered safe to re-enable
    should_be_enabled_tests = {
        name: stats
        for name, stats in all_tests.items()
        if stats.get("num_green") and stats.get("num_red") == 0
    }
    still_flaky_tests = {
        name: stats
        for name, stats in all_tests.items()
        if name not in should_be_enabled_tests
    }

    # Keyed by the record key so that each test appears at most once in the upload
    records = {}
    for test_id, stats in all_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)
        _, name, classname, filename = get_disabled_test_name(test_id)

        key, record = prepare_record(
            workflow_id=workflow_id,
            workflow_run_attempt=workflow_run_attempt,
            name=name,
            classname=classname,
            filename=filename,
            flaky=test_id in still_flaky_tests,
            num_green=num_green,
            num_red=num_red,
        )
        records[key] = record

    # Log the results
    print(f"The following {len(should_be_enabled_tests)} tests should be re-enabled:")
    for test_id in should_be_enabled_tests:
        disabled_test_name, _, _, filename = get_disabled_test_name(test_id)
        print(f" {disabled_test_name} from {filename}")

    print(f"The following {len(still_flaky_tests)} are still flaky:")
    for test_id, stats in still_flaky_tests.items():
        num_green = stats.get("num_green", 0)
        num_red = stats.get("num_red", 0)

        disabled_test_name, _, _, filename = get_disabled_test_name(test_id)
        print(
            f" {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}"
        )

    upload_workflow_stats_to_s3(
        workflow_id,
        workflow_run_attempt,
        "rerun_disabled_tests",
        list(records.values()),
    )


def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
    """
    Find the list of all disabled tests that should be re-enabled: download every
    test report of the workflow run attempt, sum the pass/fail counters of each test
    across all rerun-disabled-tests reports, and save the aggregated results.
    """
    # Aggregated across all jobs
    all_tests: dict[str, dict[str, int]] = {}

    # Use the function's own parameters rather than the module-level argparse
    # namespace, so main() also works when called from another module
    for report in get_test_reports(repo, workflow_run_id, workflow_run_attempt):
        tests = process_report(report)

        # The scheduled workflow has both rerun disabled tests and memory leak check jobs.
        # We are only interested in the former here
        if not is_rerun_disabled_tests(tests):
            continue

        for name, stats in tests.items():
            if name not in all_tests:
                # Copy so later accumulation does not mutate the per-report dict
                all_tests[name] = stats.copy()
            else:
                all_tests[name]["num_green"] += stats.get("num_green", 0)
                all_tests[name]["num_red"] += stats.get("num_red", 0)

    save_results(
        workflow_run_id,
        workflow_run_attempt,
        all_tests,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Upload test artifacts from GHA to S3")
    parser.add_argument(
        "--workflow-run-id",
        type=int,
        required=True,
        help="id of the workflow to get artifacts from",
    )
    parser.add_argument(
        "--workflow-run-attempt",
        type=int,
        required=True,
        help="which retry of the workflow this is",
    )
    parser.add_argument(
        "--repo",
        type=str,
        required=True,
        help="which GitHub repo this workflow run belongs to",
    )

    args = parser.parse_args()
    main(args.repo, args.workflow_run_id, args.workflow_run_attempt)