from __future__ import annotations

import json
import os
import re
import time
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, cast

import requests

from tools.stats.upload_stats_lib import (
    _get_request_headers,
    download_s3_artifacts,
    get_job_id,
    unzip,
    upload_workflow_stats_to_s3,
)


# Matches GHA job names of the form "<build name> / test (<test config>, ...)".
REGEX_JOB_INFO = r"(.*) \/ .*test \(([^,]*), .*\)"


@lru_cache(maxsize=1000)
def get_job_name(job_id: int) -> str:
    """Fetch the job name for a job id from the GitHub Actions API."""
    try:
        return cast(
            str,
            requests.get(
                f"https://api.github.com/repos/pytorch/pytorch/actions/jobs/{job_id}",
                headers=_get_request_headers(),
            ).json()["name"],
        )
    except Exception as e:
        print(f"Failed to get job name for job id {job_id}: {e}")
        return "NoJobName"


@lru_cache(maxsize=1000)
def get_build_name(job_name: str) -> str:
    """Extract the build environment portion of a job name."""
    try:
        return re.match(REGEX_JOB_INFO, job_name).group(1)  # type: ignore[union-attr]
    except AttributeError:
        print(f"Failed to match job name: {job_name}")
        return "NoBuildEnv"


@lru_cache(maxsize=1000)
def get_test_config(job_name: str) -> str:
    """Extract the test config portion of a job name."""
    try:
        return re.match(REGEX_JOB_INFO, job_name).group(2)  # type: ignore[union-attr]
    except AttributeError:
        print(f"Failed to match job name: {job_name}")
        return "NoTestConfig"


def get_td_exclusions(
    workflow_run_id: int, workflow_run_attempt: int
) -> dict[str, Any]:
    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        os.chdir(temp_dir)

        # Download and extract all the reports (both GHA and S3)
        s3_paths = download_s3_artifacts(
            "test-jsons", workflow_run_id, workflow_run_attempt
        )
        for path in s3_paths:
            unzip(path)

        grouped_tests: dict[str, Any] = defaultdict(lambda: defaultdict(set))
        for td_exclusions in Path(".").glob("**/td_exclusions*.json"):
            with open(td_exclusions) as f:
                exclusions = json.load(f)
            # The job id depends only on the file, not on the individual
            # exclusion, so resolve it once per file instead of once per
            # excluded test.
            job_id = get_job_id(td_exclusions)
            job_name = get_job_name(job_id)
            build_name = get_build_name(job_name)
            test_config = get_test_config(job_name)
            for exclusion in exclusions["excluded"]:
                grouped_tests[build_name][test_config].add(exclusion["test_file"])

        # Sets are not JSON-serializable, so convert them to sorted lists.
        for build_name, build in grouped_tests.items():
            for test_config, test_files in build.items():
                grouped_tests[build_name][test_config] = sorted(test_files)
        return grouped_tests


def group_test_cases(test_cases: list[dict[str, Any]]) -> dict[str, Any]:
    start = time.time()
    grouped_tests: dict[str, Any] = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        )
    )
    for test_case in test_cases:
        job_name = get_job_name(test_case["job_id"])
        build_name = get_build_name(job_name)
        if "bazel" in build_name:
            continue
        test_config = get_test_config(job_name)
        class_name = test_case.pop("classname", "NoClass")
        name = test_case.pop("name", "NoName")
        invoking_file = test_case.pop("invoking_file", "NoFile")
        # Invoking files arrive as dotted module paths; convert to file paths.
        invoking_file = invoking_file.replace(".", "/")
        # These are constant across the upload, so drop them from each record.
        test_case.pop("workflow_id")
        test_case.pop("workflow_run_attempt")
        grouped_tests[build_name][test_config][invoking_file][class_name][name].append(
            test_case
        )

    print(f"Time taken to group tests: {time.time() - start}")
    return grouped_tests
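

# Illustrative sketch only (not part of the original module): the nested
# mapping built by group_test_cases has the shape
#
#   build name -> test config -> invoking file -> class name -> test name
#       -> list of runs (one dict per execution of that test),
#
# so a hypothetical lookup (names below are made up for illustration) reads:
#
#   runs = grouped_tests["linux-build"]["default"]["test/test_ops"]["TestOps"]["test_add"]
#   # len(runs) > 1 means the test executed more than once, i.e. it was rerun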


def get_reruns(grouped_tests: dict[str, Any]) -> dict[str, Any]:
    reruns: dict[str, Any] = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        )
    )
    for build_name, build in grouped_tests.items():
        for test_config, test_config_data in build.items():
            for invoking_file, invoking_file_data in test_config_data.items():
                for class_name, class_data in invoking_file_data.items():
                    for test_name, test_data in class_data.items():
                        # More than one entry for a test name means the test
                        # was run more than once, i.e. it was rerun.
                        if len(test_data) > 1:
                            # These files are expected to report multiple
                            # entries per test, so skip them.
                            if invoking_file in (
                                "distributed/test_distributed_spawn",
                                "onnx/test_fx_to_onnx_with_onnxruntime",
                                "distributed/algorithms/quantization/test_quantization",
                            ):
                                continue
                            reruns[build_name][test_config][invoking_file][class_name][
                                test_name
                            ] = test_data
    return reruns


def get_invoking_file_summary(grouped_tests: dict[str, Any]) -> dict[str, Any]:
    invoking_file_summary: dict[str, Any] = defaultdict(
        lambda: defaultdict(lambda: defaultdict(lambda: {"count": 0, "time": 0.0}))
    )
    for build_name, build in grouped_tests.items():
        for test_config, test_config_data in build.items():
            for invoking_file, invoking_file_data in test_config_data.items():
                for class_data in invoking_file_data.values():
                    for test_data in class_data.values():
                        # Count each test name once, but accumulate the time
                        # spent across all of its runs.
                        summary = invoking_file_summary[build_name][test_config][
                            invoking_file
                        ]
                        summary["count"] += 1
                        for i in test_data:
                            summary["time"] += i["time"]

    return invoking_file_summary


def upload_additional_info(
    workflow_run_id: int, workflow_run_attempt: int, test_cases: list[dict[str, Any]]
) -> None:
    grouped_tests = group_test_cases(test_cases)
    reruns = get_reruns(grouped_tests)
    exclusions = get_td_exclusions(workflow_run_id, workflow_run_attempt)
    invoking_file_summary = get_invoking_file_summary(grouped_tests)

    upload_workflow_stats_to_s3(
        workflow_run_id,
        workflow_run_attempt,
        "additional_info/reruns",
        [reruns],
    )
    upload_workflow_stats_to_s3(
        workflow_run_id,
        workflow_run_attempt,
        "additional_info/td_exclusions",
        [exclusions],
    )
    upload_workflow_stats_to_s3(
        workflow_run_id,
        workflow_run_attempt,
        "additional_info/invoking_file_summary",
        [invoking_file_summary],
    )
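

# Hedged usage sketch (an assumption, not part of this module): test_cases is
# normally assembled by the caller from the downloaded test-report artifacts,
# one dict per test case, and must include "job_id", "workflow_id",
# "workflow_run_attempt", and "time" keys. All values below are placeholders.
#
#   test_cases = [
#       {
#           "job_id": 1,
#           "workflow_id": 2,
#           "workflow_run_attempt": 1,
#           "classname": "TestOps",
#           "name": "test_add",
#           "invoking_file": "test.test_ops",
#           "time": 0.01,
#       },
#   ]
#   upload_additional_info(
#       workflow_run_id=2, workflow_run_attempt=1, test_cases=test_cases
#   )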